diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5b16903b8d533c947a250c5493504854a294ac9e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,118 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +.github/workflows/logo.gif filter=lfs diff=lfs merge=lfs -text +assets/egg.mp4 filter=lfs diff=lfs merge=lfs -text +examples/Comp-Attn.pdf filter=lfs diff=lfs merge=lfs -text +examples/InstanceV.pdf filter=lfs diff=lfs merge=lfs -text +examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text +examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 filter=lfs diff=lfs merge=lfs -text +examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text +examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 filter=lfs diff=lfs merge=lfs -text +examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 filter=lfs diff=lfs merge=lfs -text +output/_smoke_352x640/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/_smoke_352x640/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text +output/_smoke_352x640/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/_smoke_352x640/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/_smoke_352x640/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-fron.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-next-to.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-ca.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-sm.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-vi.mp4 filter=lfs diff=lfs merge=lfs 
-text +output/wan2.1-1.3b-mc-lora_352x640_stable/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_352x640_stable/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_batch10/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs 
diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-mc-lora_epoch4_1/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-statemachine-egg/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-statemachine-egg_cooked2raw/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-statemachine-egg_moveup_long/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-statemachine-egg_moveup_long20/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text +output/wan2.1-1.3b-statemachine-egg_moveup_long20_promptclean/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/boat_seagull_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/deer_approach_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/dog_running.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/dog_running_baseline.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/four_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/four_pigeons_orbit_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/multi_instances_animals.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/single_car_sweep.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/three_diagonal_motion.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_crossing_athletes.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_scooters_crossing_20260105_112906.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_scooters_crossing_20260105_113836.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_scooters_crossing_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev/two_students_drone_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_00.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_01.mp4 filter=lfs diff=lfs merge=lfs -text 
+outputs/instancev-new/case_02.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_03.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_04.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_05.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_06.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_07.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_08.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_09.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_10.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_11.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_12.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_13.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_14.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_15.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_16.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_17.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_18.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev-new/case_19.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev_iground_infer_20260107_024437.mp4 filter=lfs diff=lfs merge=lfs -text +outputs/instancev_iground_test.mp4 filter=lfs diff=lfs merge=lfs -text +video.mp4 filter=lfs diff=lfs merge=lfs -text +video_1_Wan2.1-T2V-1.3B-LoRA2.mp4 filter=lfs diff=lfs merge=lfs -text +video_1_Wan2.1-T2V-1.3B-LoRA_epoch1.mp4 filter=lfs diff=lfs merge=lfs -text +video_1_Wan2.1-T2V-1.3B.mp4 filter=lfs diff=lfs merge=lfs -text +video_1_Wan2.1-T2V-1.3B_LoRA.mp4 filter=lfs diff=lfs merge=lfs -text +video_comp_attn_pipeline[[:space:]]copy.mp4 filter=lfs diff=lfs merge=lfs -text +video_comp_attn_pipeline.mp4 filter=lfs diff=lfs merge=lfs -text +wandb/run-20251211_101851-syoqkmhy/run-syoqkmhy.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20251211_172331-jxaicuod/run-jxaicuod.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20251225_172459-gjtz0um5/run-gjtz0um5.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20251225_214534-3dh8lbav/run-3dh8lbav.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20251229_100816-zirij84a/run-zirij84a.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260102_054910-38oaloji/run-38oaloji.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260102_104929-zd02vtce/run-zd02vtce.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260102_162705-mr7vgtqn/run-mr7vgtqn.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260103_090415-36yjbun5/run-36yjbun5.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260103_115016-kurow4tk/run-kurow4tk.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260106_030539-rupbhtts/run-rupbhtts.wandb filter=lfs diff=lfs merge=lfs -text +wandb/run-20260110_110203-bl4gd6wi/run-bl4gd6wi.wandb filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/logo.gif b/.github/workflows/logo.gif new file mode 100644 index 0000000000000000000000000000000000000000..ef5717efc17bbb2018a37a530a9d7a09e86277d9 --- /dev/null +++ b/.github/workflows/logo.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36a7627b7f0f0a508ec64aba72e5d95d38dfe7958bd8cf42d2a63f6ac2641529 +size 149067 diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..f31e6bb7bf1a10f9394f9a395c093aa37eba4855 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,29 @@ +name: release + +on: + push: + tags: + - 'v**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-publish + cancel-in-progress: true + +jobs: + build-n-publish: + runs-on: ubuntu-20.04 + #if: startsWith(github.event.ref, 'refs/tags') + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: '3.10' + - name: Install wheel + run: pip install wheel==0.44.0 && pip install -r requirements.txt + - name: Build DiffSynth + run: python setup.py sdist bdist_wheel + - name: Publish package to PyPI + run: | + pip install twine + twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6fd0d8e14982668f34621ebbbe5836978842d29e --- /dev/null +++ b/.gitignore @@ -0,0 +1,175 @@ +/data +/models +/scripts +/diffusers +*.pkl +*.safetensors +*.pth +*.ckpt +*.pt +*.bin +*.DS_Store +*.msc +*.mv +log*.txt + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file diff --git a/.home/.modelscope/credentials/session b/.home/.modelscope/credentials/session new file mode 100644 index 0000000000000000000000000000000000000000..ed4aa06347fbae803af8392fe15260d03c67e064 --- /dev/null +++ b/.home/.modelscope/credentials/session @@ -0,0 +1 @@ +13921be3c1924b38a5a21db02dce6b94 \ No newline at end of file diff --git a/.swp b/.swp new file mode 100644 index 0000000000000000000000000000000000000000..f1e005c3f5acc166d0b706ee67c8d4f08dad0a9d Binary files /dev/null and b/.swp differ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0e5a49f5f70af9e9d37278f72315d1b1afd34895 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2023] [Zhongjie Duan] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1250216a095ae2dfa38d40c5123a118149d58bdf --- /dev/null +++ b/README.md @@ -0,0 +1,783 @@ +# DiffSynth-Studio + + modelscope%2FDiffSynth-Studio | Trendshift

+ +[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/) +[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE) +[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues) +[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/) +[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/) + +[Switch to the Chinese version](./README_zh.md) + +## Introduction + +Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We aim to foster technological innovation through framework development, bring together the strength of the open-source community, and push the boundaries of generative model technology! + +DiffSynth currently includes two open-source projects: +* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support. +* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features. + +[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. You are welcome to try the product features we have carefully built: + +* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home +* ModelScope Civision (for global users): https://modelscope.ai/civision/home + +> DiffSynth-Studio documentation: [Chinese version](/docs/zh/README.md), [English version](/docs/en/README.md) + +We believe that a well-designed open-source code framework lowers the barrier to technical exploration. We have built many [interesting technologies](#innovative-achievements) on top of this codebase. You may have plenty of bold ideas of your own, and with DiffSynth-Studio you can realize them quickly. That is why we have prepared detailed documentation for developers: we hope it helps you understand the principles of Diffusion models, and we look forward to expanding the boundaries of this technology together with you. + +## Update History + +> DiffSynth-Studio has undergone a major version update, and some legacy features are no longer maintained. If you need them, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) prior to the major version update. + +> This project currently has a small development team, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). As a result, new features land relatively slowly and our capacity to respond to and resolve issues is limited. We apologize for this and ask for your understanding. + +- **December 9, 2025** We release an experimental model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA).
This model takes an image as input and outputs a LoRA. Although this version still leaves significant room for improvement in generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. + +- **December 4, 2025** DiffSynth-Studio 2.0 released, with many new features: + - [Documentation](/docs/en/README.md) is now online and is still being continuously improved + - [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded: it now supports layer-level disk offload, freeing both system memory and VRAM + - New model support + - Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/) + - FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/) + - Training framework upgrade + - [Split Training](/docs/zh/Training/Split_Training.md): Automatically splits the training process into two stages, data processing and training (even when training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, run in the data processing stage, while the remaining computations run in the training stage. This yields faster training and lower VRAM usage. + - [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): A training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model. + - [FP8 Training](/docs/zh/Training/FP8_Precision.md): During training, FP8 can be applied to any frozen model, i.e., a model whose gradients are disabled or only affect LoRA weights (see the minimal FP8 sketch at the end of the update history below). + +
+More + +- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos. + +- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project. + +- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem. + +- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py). + +- **September 9, 2025** Our training framework supports various training modes. Currently adapted for Qwen-Image, in addition to the standard SFT training mode, Direct Distill is now supported. Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functions. + +- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/). + +- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py). + +- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py). + +- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) + +- **August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family! + +- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). 
The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py). + +- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train Qwen-Image models' ControlNet and EliGen models. We aim to promote technological development through open-sourcing! + +- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py). + +- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py). + +- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models. + +- **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet). + +- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration. + +- **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family! + +- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/). + +- **July 28, 2025** Wan 2.2 open-sourced. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/). 
+ +- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks. + - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) + - GitHub Repository: https://github.com/modelscope/Nexus-Gen + - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) + - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) + - Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) + +- **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out. + +- **March 25, 2025** Our new open-source project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now open-sourced! Focused on stable model deployment, targeting industry, providing better engineering support, higher computational performance, and more stable features. + +- **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/). + +- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/). + +- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). + +- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/). + +- **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/). + - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) + - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen) + - Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen) + - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet) + +- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/). 
+ +- **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images. + - Paper: https://arxiv.org/abs/2412.12888 + - Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug + - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1) + - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon) + +- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models and can be freely combined, even if their structures are different. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/). + +- **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1). + +- **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including: + - Text-to-video + - Video editing + - Self super-resolution + - Video interpolation + +- **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. Now you can create stunning images with the assistance of AI using the brush! + - Use it in our [WebUI](#usage-in-webui). + +- **August 21, 2024** DiffSynth-Studio now supports FLUX. + - Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md) + - LoRA, ControlNet, and other addon models will be released soon. + +- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames. + - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) + - Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/). + - Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1). + - Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130). + - You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)! + +- **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from "me" to "us". Of course, I will still participate in subsequent development and maintenance work. + +- **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution. + - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/) + - Source code has been released in this project. 
+ - Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224). + +- **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. The development work of this project officially began. + +- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm. + - sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend). + - Demonstration videos have been showcased on Bilibili, including three tasks: + - [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE) + - [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p) + - [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF) + - Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265). + - Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend). + +- **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine. + - Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL). + - FastSDXL includes a trainable OLSS scheduler to improve efficiency. + - The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler). + - Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677). + - Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj). + - Since OLSS requires additional training, we did not implement it in this project. + +- **August 29, 2023** We propose DiffSynth, a video synthesis framework. + - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/). + - Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth). + - Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463). + +
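As referenced in the December 4, 2025 entry above, the following is a minimal, framework-agnostic sketch of the FP8 idea for frozen weights: the base weight is stored in `float8_e4m3fn` and upcast to `bfloat16` only at compute time, while a small trainable LoRA stays in `bfloat16`. This is an illustrative example only, not DiffSynth-Studio's actual implementation; it assumes PyTorch >= 2.1 for `torch.float8_e4m3fn`, and the class name is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FP8FrozenLinearWithLoRA(nn.Module):
    """Frozen base weight stored in FP8; trainable LoRA kept in bf16 (illustrative only)."""

    def __init__(self, in_features: int, out_features: int, rank: int = 16):
        super().__init__()
        base = torch.randn(out_features, in_features) * 0.02
        # Frozen base weight: stored as FP8 (1 byte/element vs. 2 for bf16); no gradients flow into it.
        self.register_buffer("weight_fp8", base.to(torch.float8_e4m3fn))
        # Trainable low-rank adapters in bf16 (standard LoRA init: A random, B zero).
        self.lora_a = nn.Parameter(torch.randn(rank, in_features, dtype=torch.bfloat16) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast the FP8 weight to bf16 only for the matmul; only the LoRA path receives gradients.
        w = self.weight_fp8.to(torch.bfloat16)
        return F.linear(x, w) + F.linear(F.linear(x, self.lora_a), self.lora_b)


layer = FP8FrozenLinearWithLoRA(64, 64)
x = torch.randn(2, 64, dtype=torch.bfloat16)
y = layer(x)
print(y.shape, y.dtype)  # torch.Size([2, 64]) torch.bfloat16
```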
+ +## Installation + +Install from source (recommended): + +``` +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +
+Other installation methods + +Install from PyPI (version updates may be delayed; for latest features, install from source) + +``` +pip install diffsynth +``` + +If you meet problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages: + +* [torch](https://pytorch.org/get-started/locally/) +* [sentencepiece](https://github.com/google/sentencepiece) +* [cmake](https://cmake.org) +* [cupy](https://docs.cupy.dev/en/stable/install.html) + +
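After installation, a quick import check can confirm that the package and PyTorch are visible. This is a minimal sketch and assumes the requirements above were installed; the CUDA line only reports GPU visibility and is optional on CPU-only machines. The pipeline classes imported here are the same ones used in the quick-start examples below.

```python
# Minimal post-install sanity check; no models are downloaded by this script.
import torch
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("DiffSynth pipeline imported:", QwenImagePipeline.__name__, ModelConfig.__name__)
```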
+ +## Basic Framework + +DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training. + +
+Environment Variable Configuration + +> Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md). +> +> By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows: +> +> ```python +> import os +> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai" +> ``` +> +> To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source). + +
+ +### Image Synthesis + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +#### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md) + +
+ +Quick Start + +Running the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory. + +```python +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=42, rand_device="cuda") +image.save("image.jpg") +``` + +
+ +
+ +Examples + +Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/) + +| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| + +
+ +#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md) + +
+ +Quick Start + +Running the following code will quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM. + +```python +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image.jpg") +``` + +
+ +
+ +Examples + +Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/) + +| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation | +|-|-|-|-|-| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| + +
+ +#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md) + +
+ +Quick Start + +Running the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +
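Once constructed, the pipeline can be reused for many generations without reloading any weights. A small sketch that continues the snippet above and sweeps a few seeds for the same prompt, using only the call signature already shown (`prompt`, `seed`, `num_inference_steps`):

```python
# Continues the Quick Start above: `pipe` and `prompt` are already defined,
# and the model weights stay resident between calls.
for seed in range(4):
    image = pipe(prompt, seed=seed, num_inference_steps=40)
    image.save(f"image_seed{seed}.jpg")
```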
+ +
+ +Model Lineage + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +
+ +Examples + +Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/) + +| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| 
+|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| 
+|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| +|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| + +
+ +#### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md) + +
+ +Quick Start + +Running the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM. + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1, +) +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") +``` + +
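The `vram_config` dictionary is applied per `ModelConfig`, so different components can follow different policies; the FLUX.2 example above, for instance, leaves the VAE at the pipeline defaults while the transformer and text encoder use disk offload. Below is a sketch of mixing policies for FLUX.1-dev, storing only the large DiT in FP8 and keeping the smaller components in bfloat16. The keys and file patterns reuse the snippet above; whether this particular split is the best fit for your GPU is an assumption to verify.

```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig

# FP8 storage for the large transformer; bfloat16 for the text encoders and VAE.
fp8_config = {
    "offload_dtype": torch.float8_e4m3fn, "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn, "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}
bf16_config = dict(
    fp8_config,
    offload_dtype=torch.bfloat16,
    onload_dtype=torch.bfloat16,
    preparing_dtype=torch.bfloat16,
)

pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **fp8_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **bf16_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **bf16_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **bf16_config),
    ],
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
```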
+ +
+ +Model Lineage + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
+ +
+ +Examples + +Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/) + +| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| 
+|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| +|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, 
`flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| +|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| + +
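The "Extra Args" column above lists the additional keyword arguments each model expects when calling the pipeline; the linked inference scripts are the authoritative reference for how to construct them. As a rough, hypothetical illustration of the pattern for FLUX.1-Kontext-dev (the exact type accepted by `kontext_images` is an assumption here; check the linked script):

```python
from PIL import Image

# Hypothetical sketch: `pipe` is a FluxImagePipeline loaded with the
# FLUX.1-Kontext-dev weights (see the linked inference script for the exact
# model_configs). The table lists `kontext_images` as its extra argument.
reference = Image.open("input.jpg").convert("RGB")
image = pipe(
    prompt="Turn the scene into a watercolor painting.",
    kontext_images=[reference],  # assumption: reference image(s) for editing
    seed=0,
)
image.save("edited.jpg")
```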
+ +### Video Synthesis + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md) + +
+ +Quick Start + +Running the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM. + +```python +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` + +
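The pipeline returns the generated clip as a sequence of frames, which `save_video` encodes to MP4 at the given frame rate and quality. If you want to inspect or post-process individual frames before encoding, here is a short sketch continuing the snippet above; treating each frame as a PIL image is an assumption about the current return type, so verify it on your installed version:

```python
# Continues the Quick Start above: `video` is the frame sequence returned by the pipeline.
# Assumption: each element behaves like a PIL.Image in current releases.
for i, frame in enumerate(video[:4]):
    frame.save(f"frame_{i:03d}.jpg")
print(f"generated {len(video)} frames")
```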
+ +
+ +Model Lineage + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
+ +
+ +Examples + +Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) + +| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| 
+|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| 
+|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| 
+|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| +|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, 
`vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| + +
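As with the image pipelines, the "Extra Args" column lists the keyword arguments each variant adds on top of the base call; the linked scripts show the authoritative usage. A hypothetical sketch of the image-to-video pattern (e.g. Wan2.1-I2V-14B-480P, whose extra argument is `input_image`; the argument's exact type is an assumption, so check the linked example):

```python
from PIL import Image
from diffsynth.utils.data import save_video

# Hypothetical sketch: `pipe` is a WanVideoPipeline loaded with the
# Wan2.1-I2V-14B-480P weights (see the linked inference script for the exact
# model_configs). The table lists `input_image` as its extra argument.
first_frame = Image.open("first_frame.jpg").convert("RGB")
video = pipe(
    prompt="镜头缓缓推进,小狗在草地上奔跑。",
    input_image=first_frame,  # assumption: the conditioning frame for image-to-video
    seed=0, tiled=True,
)
save_video(video, "i2v.mp4", fps=15, quality=5)
```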
+ +## Innovative Achievements + +DiffSynth-Studio is not just an engineering framework for diffusion models; it also serves as an incubator for the innovative research achievements listed below. + +
+ +AttriCtrl: Attribute Intensity Control for Image Generation Models + +- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151) +- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev) + +|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| +|-|-|-|-|-| +|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| + +
+ + +
+ +AutoLoRA: Automated LoRA Retrieval and Fusion + +- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107) +- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) + +||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| +|-|-|-|-|-| +|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| +|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| +|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| +|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| + +
+ + +
+ +Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing + +- Detailed Page: https://github.com/modelscope/Nexus-Gen +- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) +- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) +- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) + +![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg) + +
+ + +
+ +ArtAug: Aesthetic Enhancement for Image Generation Models + +- Detailed Page: [./examples/ArtAug/](./examples/ArtAug/) +- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1) +- Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0) + +|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)| + +
+ + +
+ +EliGen: Precise Image Partition Control + +- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) +- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) +- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen) +- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen) +- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet) + +|Entity Control Region|Generated Image| +|-|-| +|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)| + +
+ + +
+ +ExVideo: Extended Training for Video Generation Models + +- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) +- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130) +- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo) +- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) + +https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc + +
+ + +
+ +Diffutoon: High-Resolution Anime-Style Video Rendering + +- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/) +- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224) +- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon) + +https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd + +
+ + +
+ +DiffSynth: The Original Version of This Project + +- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/) +- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463) +- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth) + +https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea + +
+ diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..cdfc72fcb054ff5f4b21aefd74834b40574656c9 --- /dev/null +++ b/README_zh.md @@ -0,0 +1,784 @@ +# DiffSynth-Studio + + modelscope%2FDiffSynth-Studio | Trendshift

+ +[![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/) +[![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE) +[![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues) +[![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/) +[![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/) + +[Switch to English](./README.md) + +## 简介 + +欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界! + +DiffSynth 目前包括两个开源项目: +* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。 +* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。 + +[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能: + +* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home +* ModelScope Civision (for global users): https://modelscope.ai/civision/home + +> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md) + +我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio,你可以快速实现这些想法。为此,我们为开发者准备了详细的文档,我们希望通过这些文档,帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。 + +## 更新历史 + +> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 + +> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 + +- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。 + +- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线 + - [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中 + - [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存 + - 新模型支持 + - Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/) + - FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/) + - 训练框架升级 + - [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型),在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。 + - [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。 + - [FP8 训练](/docs/zh/Training/FP8_Precision.md):FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。 + +
+更多 + +- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。 + +- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。 + +- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。 + +- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。 + +- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。 + +- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。 + +- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。 + +- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。 + +- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) + +- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员! + +- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。 + +- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展! 
+ +- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。 + +- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。 + +- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。 + +- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。 + +- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。 + +- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! + +- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 + +- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 + +- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 + - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) + - Github 仓库: https://github.com/modelscope/Nexus-Gen + - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) + - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) + - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) + +- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。 + +- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。 + +- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。 + +- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。 + +- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。 + +- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。 + +- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 
[./examples/EntityControl](./examples/EntityControl/)。 + - 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) + - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen) + - 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen) + - 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet) + +- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。 + +- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。 + - 论文: https://arxiv.org/abs/2412.12888 + - 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug + - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1) + - 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线) + +- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。 + +- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。 + +- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括: + - 文本到视频 + - 视频编辑 + - 自我超分 + - 视频插帧 + +- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了! + - 在我们的 [WebUI](#usage-in-webui) 中使用它。 + +- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。 + - 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md) + - LoRA、ControlNet 和其他附加模型将很快推出。 + +- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。 + - [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/) + - 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。 + - 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。 + - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。 + - 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo! 
+ +- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。 + +- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。 + - [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/) + - 源代码已在此项目中发布。 + - 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。 + +- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。 + +- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。 + - sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。 + - 演示视频已在 Bilibili 上展示,包含三个任务: + - [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE) + - [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p) + - [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF) + - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。 + - 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。 + +- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。 + - 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。 + - FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。 + - OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。 + - 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。 + - 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。 + - 由于 OLSS 需要额外训练,我们未在本项目中实现它。 + +- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。 + - [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。 + - 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。 + - 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。 + +
+ +## 安装 + +从源码安装(推荐): + +``` +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +
+其他安装方式 + +从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装) + +``` +pip install diffsynth +``` + +如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档: + +* [torch](https://pytorch.org/get-started/locally/) +* [sentencepiece](https://github.com/google/sentencepiece) +* [cmake](https://cmake.org) +* [cupy](https://docs.cupy.dev/en/stable/install.html) + +
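+
+安装完成后,可以用下面的 Python 片段做一个简单的自检(示意代码:仅检查 diffsynth 与上述依赖能否读取到版本信息,包名以实际安装为准):
+
+```python
+# 安装自检示意:打印 diffsynth 及关键上游依赖的版本号
+from importlib.metadata import version, PackageNotFoundError
+
+for name in ["diffsynth", "torch", "sentencepiece", "cupy"]:
+    try:
+        print(name, version(name))
+    except PackageNotFoundError:
+        print(name, "未安装(或未提供包元数据)")
+
+import diffsynth  # 能成功导入即说明安装基本可用
+```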
+ +## 基础框架 + +DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。 + +
+环境变量配置 + +> 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。 +> +> 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型: +> +> ```python +> import os +> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai" +> ``` +> +> 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。 + +
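+
+下面给出一个完整的小示例(示意代码):下载源相关的环境变量应在触发模型下载之前设置,例如放在构建 Pipeline 之前;`DIFFSYNTH_DOWNLOAD_SOURCE` 的可选值请以上面链接的文档为准,此处不做假设。
+
+```python
+import os
+
+# 先设置下载源,再执行任何会触发模型下载的操作
+os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"   # 非中国区域用户可指向魔搭国际站(取自上文)
+# os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "..."        # 可选值见 docs/zh/Pipeline_Usage/Environment_Variables.md
+
+# 之后再导入并构建管线(以 Qwen-Image 为例,具体加载方式见下文“快速开始”)
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+```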
+ +### 图像生成模型 + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +#### Z-Image:[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=42, rand_device="cuda") +image.save("image.jpg") +``` + +
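+
+作为补充,下面是一个示意性的配置草稿:当显存充足时(例如 24GB 以上,数值为假设),可以不做 CPU Offload,把所有权重以 bfloat16 常驻 GPU,其余加载与调用方式与上例一致。配置键沿用上例,是否仍需保留 `vram_limit` 等参数请以实际实现与硬件为准。
+
+```python
+# 示意:显存充足时的“全 GPU”配置草稿(阈值与数值均为假设,仅供参考)
+import torch
+
+vram_config_full_gpu = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cuda",        # 不再卸载到 CPU
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cuda",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+# 将上例 ModelConfig 中展开的 **vram_config 替换为 **vram_config_full_gpu 即可
+```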
+ +
+ +示例代码 + +Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/) + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| + +
+ +#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。 + +```python +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image.jpg") +``` + +
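+
+管线加载(包含从磁盘流式读取权重)通常是最耗时的一步,加载完成后可以复用同一个 pipe 连续生成多张图。下面的示意片段沿用上例的调用接口,用不同的随机种子与步数批量出图(种子与步数的取值仅作演示):
+
+```python
+# 示意:复用上文已构建的 pipe 与 prompt,批量生成多张图片
+for seed in range(4):
+    image = pipe(
+        prompt,
+        seed=seed,
+        rand_device="cuda",
+        num_inference_steps=28,  # 步数越少速度越快,细节可能相应减少
+    )
+    image.save(f"image_seed{seed}.jpg")
+```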
+ +
+ +示例代码 + +FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/) + +|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| + +
+ +#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +
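+
+下面的示意片段在上例基础上对一组提示词批量生成,并把结果按序号保存到独立目录(输出目录与第二条提示词均为示例):
+
+```python
+# 示意:复用上文的 pipe,对多条提示词批量生成并保存
+from pathlib import Path
+
+prompts = [
+    "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",
+    "黄昏的江南水乡,石桥倒影,暖色灯笼,写实摄影风格,细节丰富。",  # 示例提示词
+]
+output_dir = Path("outputs/qwen_image")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+for i, p in enumerate(prompts):
+    image = pipe(p, seed=i, num_inference_steps=40)
+    image.save(output_dir / f"image_{i:02d}.jpg")
+```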
+ +
+ +模型血缘 + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +
+ +示例代码 + +Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/) + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| 
+|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| 
+|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| + +
+ +#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1, +) +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") +``` + +
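+
+上例中 `vram_limit` 的取值可以拆开来看:`torch.cuda.mem_get_info("cuda")` 返回 `(空闲显存, 总显存)`(单位为字节),取 `[1]` 即总显存,除以 `1024 ** 3` 换算为 GiB,再减去 1 GiB 作为余量。下面的示意片段把这一步单独打印出来,也可以据此手动指定一个更保守的上限(具体数值为示例):
+
+```python
+# 示意:查看当前 GPU 显存,并手动设定 vram_limit
+import torch
+
+free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
+total_gib = total_bytes / (1024 ** 3)
+print(f"空闲 {free_bytes / (1024 ** 3):.1f} GiB / 总计 {total_gib:.1f} GiB")
+
+vram_limit = total_gib - 1   # 与上例一致:预留约 1 GiB 余量
+# vram_limit = 10.0          # 也可以直接指定一个更保守的值(示例数值)
+```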
+ +
+ +模型血缘 + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
+ +
+ +示例代码 + +FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/) + +|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| 
+|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| +|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, 
`flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| +|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| + +
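+
+上表中的「额外参数」会作为关键字参数传入管线调用。以 FLUX.1-Kontext-dev 的 `kontext_images` 为例,下面给出一个示意草稿(假设管线已按对应示例代码加载,且 `kontext_images` 接收 PIL 图像列表;参数细节请以表中示例代码为准):
+
+```python
+# 示意:向管线传入表中列出的额外参数(参数名取自上表,传参方式为假设,具体以示例代码为准)
+from PIL import Image
+
+reference = Image.open("input.jpg")  # 假设的本地参考图路径
+image = pipe(
+    prompt="Make the sky a vivid sunset while keeping the subject unchanged.",  # 示例编辑指令
+    kontext_images=[reference],      # 此处假设以图像列表形式传入
+    seed=0,
+)
+image.save("edited.jpg")
+```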
+ +### 视频生成模型 + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` + +
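+
+生成结果 `video` 可以按需要多次导出。下面的示意片段在上例基础上再导出一份低帧率预览,并把首帧保存为封面图(假设 `video` 是帧序列、每帧为 PIL 图像;`fps`、`quality` 的取值仅作演示):
+
+```python
+# 示意:对同一段生成结果做多种导出
+save_video(video, "video_preview.mp4", fps=8, quality=3)   # 低帧率预览
+save_video(video, "video_hq.mp4", fps=15, quality=8)       # 较高质量导出
+
+video[0].save("cover.jpg")  # 假设每帧为 PIL.Image,保存首帧作为封面
+```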
+ +
+ +模型血缘 + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
+ +
+ +示例代码 + +Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) + +|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| 
+|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| + +
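+
+与图像模型类似,上表中的「额外参数」同样以关键字参数的形式传入管线。以图生视频的 `input_image` 为例,下面给出一个示意草稿(假设已按对应示例代码加载 Wan2.2-I2V-A14B 等图生视频模型的管线,且 `input_image` 接收 PIL 图像;细节请以表中示例代码为准):
+
+```python
+# 示意:图生视频调用草稿(参数名取自上表,传参方式为假设,具体以示例代码为准)
+from PIL import Image
+
+first_frame = Image.open("first_frame.jpg")  # 假设的本地首帧图片
+video = pipe(
+    prompt="镜头缓缓推近,人物转身微笑,光线柔和,画面稳定。",  # 示例提示词
+    input_image=first_frame,   # 额外参数:首帧图像
+    seed=0, tiled=True,
+)
+save_video(video, "i2v.mp4", fps=15, quality=5)
+```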
+ +## 创新成果 + +DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 + +
+ +AttriCtrl: 图像生成模型的属性强度控制 + +- 论文:[AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models +](https://arxiv.org/abs/2508.02151) +- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev) + +|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9| +|-|-|-|-|-| +|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)| + +
+ + +
+ +AutoLoRA: 自动化的 LoRA 检索和融合 + +- 论文:[AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation +](https://arxiv.org/abs/2508.02107) +- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) + +||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)| +|-|-|-|-|-| +|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)| +|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)| +|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)| +|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)| + +
+ + +
+ +Nexus-Gen: 统一架构的图像理解、生成、编辑 + +- 详细页面:https://github.com/modelscope/Nexus-Gen +- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) +- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) +- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) + +![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg) + +
+ + +
+ +ArtAug: 图像生成模型的美学提升 + +- 详细页面:[./examples/ArtAug/](./examples/ArtAug/) +- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1) +- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0) + +|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)| + +
+ + +
+ +EliGen: 精准的图像分区控制 + +- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) +- 代码样例:[/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) +- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen) +- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen) +- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet) + +|实体控制区域|生成图像| +|-|-| +|![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)| + +
+ + +
+ +ExVideo: 视频生成模型的扩展训练 + +- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) +- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130) +- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)查看 +- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) + +https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc + +
+ + +
+ +Diffutoon: 高分辨率动漫风格视频渲染 + +- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/) +- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224) +- 代码样例:请前往[旧版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)查看 + +https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd + +
+ + +
+
+DiffSynth: The First Generation of This Project
+
+- Project page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
+- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
+- Code example: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
+
+https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
+
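The original DiffSynth paper reduces flicker by smoothing information across frames in latent space inside each denoising iteration. The snippet below is only a crude stand-in for that idea, averaging each frame's latent with its neighbours; the actual method is considerably more sophisticated, and the blend weight here is an arbitrary assumption.

```python
import torch

def smooth_latents_over_time(latents, weight=0.5):
    """Crude temporal smoothing of per-frame latents: blend each frame's
    latent with the mean of its two neighbours (edges are replicated).
    latents: (T, C, H, W)."""
    padded = torch.cat([latents[:1], latents, latents[-1:]], dim=0)
    neighbour_mean = (padded[:-2] + padded[2:]) / 2
    return (1 - weight) * latents + weight * neighbour_mean

print(smooth_latents_over_time(torch.randn(8, 4, 16, 16)).shape)  # torch.Size([8, 4, 16, 16])
```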
diff --git a/assets/egg.mp4 b/assets/egg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1bc6a90cc174959689180369186069f80380f84e --- /dev/null +++ b/assets/egg.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96647e399f3a7772d32bee1562a0c01fe8b273f6ad5fe2708cc10878972fce04 +size 4810228 diff --git a/comp_attn_bbox_layout.png b/comp_attn_bbox_layout.png new file mode 100644 index 0000000000000000000000000000000000000000..a76c004b603454a2a2d69f3b16e3fec2a1c541c4 Binary files /dev/null and b/comp_attn_bbox_layout.png differ diff --git a/comp_attn_trajectory.png b/comp_attn_trajectory.png new file mode 100644 index 0000000000000000000000000000000000000000..458d73e69c8b2c921e3b7d0d282f2b5f0edb5c44 Binary files /dev/null and b/comp_attn_trajectory.png differ diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb67a43fa4e5791ab58e7e40260bc3df8b6bc7cc --- /dev/null +++ b/diffsynth/__init__.py @@ -0,0 +1 @@ +from .core import * diff --git a/diffsynth/configs/__init__.py b/diffsynth/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..144a822978b12a8297e341760a74791fb12802c0 --- /dev/null +++ b/diffsynth/configs/__init__.py @@ -0,0 +1,2 @@ +from .model_configs import MODEL_CONFIGS +from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..dca078a0d75faa8e23dd4ad40d272fcda1a032fa --- /dev/null +++ b/diffsynth/configs/model_configs.py @@ -0,0 +1,518 @@ +qwen_image_series = [ + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors") + "model_hash": "0319a1cb19835fb510907dd3367c95ff", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8004730443f55db63092006dd9f7110e", + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "ed4ea5824d55ec3107b09815e318123a", + "model_name": "qwen_image_vae", + "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors") + "model_hash": "073bce9cf969e317e5662cd570c3e79c", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors") + "model_hash": "a9e54e480a628f0b956a688a81c33bab", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + "extra_kwargs": {"additional_in_dim": 4}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors") + 
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8", + "model_name": "siglip2_image_encoder", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors") + "model_hash": "5722b5c873720009de96422993b15682", + "model_name": "dinov3_image_encoder", + "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder", + }, + { + # Example: + "model_hash": "a166c33455cdbd89c0888a3645ca5c0f", + "model_name": "qwen_image_image2lora_coarse", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + }, + { + # Example: + "model_hash": "a5476e691767a4da6d3a6634a10f7408", + "model_name": "qwen_image_image2lora_fine", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64} + }, + { + # Example: + "model_hash": "0aad514690602ecaff932c701cb4b0bb", + "model_name": "qwen_image_image2lora_style", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 64, "use_residual": False} + }, +] + +wan_series = [ + { + # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors") + "model_hash": "5ec04e02b42d2580483ad69f4e76346a", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth") + "model_hash": "9c8818c2cbea55eca56c7b447df170da", + "model_name": "wan_video_text_encoder", + "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth") + "model_hash": "ccc42284ea13e1ad04693284c7a09be6", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors") + "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel", + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers" + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", 
origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_vap", + "model_class": "diffsynth.models.wan_video_mot.MotWanModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + "model_hash": "5941c53e207d62f20f9025686193c40b", + "model_name": "wan_video_image_encoder", + "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter" + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors") + "model_hash": "dbd5ec76bbf977983f972c151d545389", + "model_name": "wan_video_motion_controller", + "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "9269f8db9040a9d860eaca435be61814", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "349723183fc063b2bfc10bb2835cf677", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6d6ccde6845b95ad9114ab993d917893", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "efa44cddf936c70abd0ea28b6cbe946c", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 
'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6bfcfb3b342cb286ce886889d519a77e", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "70ddad9d3a133785da5ea371aae09504", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "b61c605c2adbd23124d152ed28e049ae", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "26bde73488a92e64cc20b0a7485b9e5b", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 
2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_animate_adapter", + "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter" + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "47dbeab5e560db3180adf51dc0232fb1", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False} + }, + { + # Example: 
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "2267d489f0ceb9f21836532952852ee5", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False}, + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "5b013604280dd715f8457c6ed6d6a626", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "966cffdcc52f9c46c391768b27637614", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel", + "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "1f5ab7703c6fc803fdded85ff040c316", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth") + "model_hash": "e1de6c02cdac79f8b739f4d3698cd216", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors") + "model_hash": "06be60f3a4526586d8431cd038a71486", + "model_name": "wans2v_audio_encoder", + "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter", + }, +] + +flux_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors") + "model_hash": "a29710fea6dddb0314663ee823598e50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": 
"94eefa3dac9cec93cb1ebaf1747d7b78", + "model_name": "flux_text_encoder_clip", + "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors") + "model_hash": "22540b49eaedbc2f2784b2091a234c7c", + "model_name": "flux_text_encoder_t5", + "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors") + "model_hash": "d02f41c13549fa5093d3521f62a5570a", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "extra_kwargs": {'input_dim': 196, 'num_blocks': 8}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + "model_hash": "0629116fce1472503a66992f96f3eb1a", + "model_name": "flux_value_controller", + "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", + }, + { + # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "52357cb26250681367488a8954c271e8", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}, + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "78d18b9101345ff695f312e7e62538c0", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}, + }, + { + # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "b001c89139b5f053c715fe772362dd2a", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + 
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_single_blocks": 0}, + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin") + "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481", + "model_name": "infiniteyou_image_projector", + "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors") + "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors") + "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab", + "model_name": "flux_lora_encoder", + "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors") + "model_hash": "30143afb2dea73d1ac580e0787628f8c", + "model_name": "flux_lora_patcher", + "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors") + "model_hash": "2bd19e845116e4f875a0a048e27fc219", + "model_name": "nexus_gen_llm", + "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "nexus_gen_editing_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "nexus_gen_generation_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin") + "model_hash": "4daaa66cc656a8fe369908693dad0a35", + "model_name": "flux_ipadapter", + "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors") + "model_hash": "04d8c1e20a1f1b25f7434f111992a33f", + "model_name": "siglip_vision_model", + "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "step1x_connector", + "model_class": "diffsynth.models.step1x_connector.Qwen2Connector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + "extra_kwargs": {"disable_guidance_embedder": True}, + }, +] + +flux2_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "28fca3d8e5bf2a2d1271748a773f6757", + "model_name": "flux2_text_encoder", + "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors") + "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "c54288e3ee12ca215898840682337b95", + "model_name": "flux2_vae", + "model_class": "diffsynth.models.flux2_vae.Flux2VAE", + }, +] + +z_image_series = [ + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors") + "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "0f050f62a88876fea6eae0a18dac5a2e", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, +] + +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..958dad42a08527b922c39070f2d38ff121f81c2a --- /dev/null +++ b/diffsynth/configs/vram_management_module_maps.py @@ -0,0 +1,197 @@ +flux_general_vram_config = { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule", +} + +VRAM_MANAGEMENT_MODULE_MAPS = { + "diffsynth.models.qwen_image_dit.QwenImageDiT": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_vae.QwenImageVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": 
"diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": { + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": { + "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit_s2v.WanS2VModel": { + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit.WanModel": { + "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule", + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": 
"diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_image_encoder.WanImageEncoder": { + "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_mot.MotWanModel": { + "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_text_encoder.WanTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vace.VaceWanModel": { + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE38": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wav2vec.WanS2VAudioEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + 
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_dit.FluxDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config, + "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config, + "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config, + "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config, + "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config, + "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config, + "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config, + "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_dit.Flux2DiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_vae.Flux2VAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + 
"diffsynth.models.z_image_dit.ZImageDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, +} diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72e501fc500321f0710c3f507930d301ab847033 --- /dev/null +++ b/diffsynth/core/__init__.py @@ -0,0 +1,5 @@ +from .attention import * +from .data import * +from .gradient import * +from .loader import * +from .vram import * diff --git a/diffsynth/core/attention/__init__.py b/diffsynth/core/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45cf8a4382397aa4ff6558f37191c726d821b95a --- /dev/null +++ b/diffsynth/core/attention/__init__.py @@ -0,0 +1 @@ +from .attention import attention_forward diff --git a/diffsynth/core/attention/attention.py b/diffsynth/core/attention/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..15b55a43aa0ee248245fd6652f731f966091b6f7 --- /dev/null +++ b/diffsynth/core/attention/attention.py @@ -0,0 +1,121 @@ +import torch, os +from einops import rearrange + + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +try: + import xformers.ops as xops + XFORMERS_AVAILABLE = True +except ModuleNotFoundError: + XFORMERS_AVAILABLE = False + + +def initialize_attention_priority(): + if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None: + return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower() + elif FLASH_ATTN_3_AVAILABLE: + return "flash_attention_3" + elif FLASH_ATTN_2_AVAILABLE: + return "flash_attention_2" + elif SAGE_ATTN_AVAILABLE: + return "sage_attention" + elif XFORMERS_AVAILABLE: + return "xformers" + else: + return "torch" + + +ATTENTION_IMPLEMENTATION = initialize_attention_priority() + + +def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if q_pattern != required_in_pattern: + q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims) + if k_pattern != required_in_pattern: + k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) + if v_pattern != required_in_pattern: + v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims) + return q, k, v + + +def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if out_pattern != required_out_pattern: + out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims) + return out + + +def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, 
dims) + return out + + +def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale) + if isinstance(out, tuple): + out = out[0] + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = sageattn(q, k, v, sm_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = xops.memory_efficient_attention(q, k, v, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False): + if compatibility_mode or (attn_mask is not None): + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale) + else: + if ATTENTION_IMPLEMENTATION == "flash_attention_3": + return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "flash_attention_2": + return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "sage_attention": + return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "xformers": + return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + else: + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) diff --git a/diffsynth/core/data/__init__.py b/diffsynth/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d494a277d81eeb2a9575155eb983d8bc3879590a --- /dev/null +++ b/diffsynth/core/data/__init__.py @@ -0,0 +1 @@ +from .unified_dataset import UnifiedDataset diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py new file mode 100644 index 
0000000000000000000000000000000000000000..c14944ea3de15dafaaff34b5badcb87c8839f6da --- /dev/null +++ b/diffsynth/core/data/operators.py @@ -0,0 +1,218 @@ +import torch, torchvision, imageio, os +import imageio.v3 as iio +from PIL import Image + + +class DataProcessingPipeline: + def __init__(self, operators=None): + self.operators: list[DataProcessingOperator] = [] if operators is None else operators + + def __call__(self, data): + for operator in self.operators: + data = operator(data) + return data + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline(self.operators + pipe.operators) + + +class DataProcessingOperator: + def __call__(self, data): + raise NotImplementedError("DataProcessingOperator cannot be called directly.") + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline([self]).__rshift__(pipe) + + +class DataProcessingOperatorRaw(DataProcessingOperator): + def __call__(self, data): + return data + + +class ToInt(DataProcessingOperator): + def __call__(self, data): + return int(data) + + +class ToFloat(DataProcessingOperator): + def __call__(self, data): + return float(data) + + +class ToStr(DataProcessingOperator): + def __init__(self, none_value=""): + self.none_value = none_value + + def __call__(self, data): + if data is None: data = self.none_value + return str(data) + + +class LoadImage(DataProcessingOperator): + def __init__(self, convert_RGB=True): + self.convert_RGB = convert_RGB + + def __call__(self, data: str): + image = Image.open(data) + if self.convert_RGB: image = image.convert("RGB") + return image + + +class ImageCropAndResize(DataProcessingOperator): + def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1): + self.height = height + self.width = width + self.max_pixels = max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.height is None or self.width is None: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + def __call__(self, data: Image.Image): + image = self.crop_and_resize(data, *self.get_height_width(data)) + return image + + +class ToList(DataProcessingOperator): + def __call__(self, data): + return [data] + + +class LoadVideo(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # 
frame_processor is built into the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, reader): + num_frames = self.num_frames + if int(reader.count_frames()) < num_frames: + num_frames = int(reader.count_frames()) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + reader = imageio.get_reader(data) + num_frames = self.get_num_frames(reader) + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(frame_id) + frame = Image.fromarray(frame) + frame = self.frame_processor(frame) + frames.append(frame) + reader.close() + return frames + + +class SequencialProcess(DataProcessingOperator): + def __init__(self, operator=lambda x: x): + self.operator = operator + + def __call__(self, data): + return [self.operator(i) for i in data] + + +class LoadGIF(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is built into the GIF loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, path): + num_frames = self.num_frames + images = iio.imread(path, mode="RGB") + if len(images) < num_frames: + num_frames = len(images) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + num_frames = self.get_num_frames(data) + frames = [] + images = iio.imread(data, mode="RGB") + for img in images: + frame = Image.fromarray(img) + frame = self.frame_processor(frame) + frames.append(frame) + if len(frames) >= num_frames: + break + return frames + + +class RouteByExtensionName(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data: str): + file_ext_name = data.split(".")[-1].lower() + for ext_names, operator in self.operator_map: + if ext_names is None or file_ext_name in ext_names: + return operator(data) + raise ValueError(f"Unsupported file: {data}") + + +class RouteByType(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data): + for dtype, operator in self.operator_map: + if dtype is None or isinstance(data, dtype): + return operator(data) + raise ValueError(f"Unsupported data: {data}") + + +class LoadTorchPickle(DataProcessingOperator): + def __init__(self, map_location="cpu"): + self.map_location = map_location + + def __call__(self, data): + return torch.load(data, map_location=self.map_location, weights_only=False) + + +class ToAbsolutePath(DataProcessingOperator): + def __init__(self, base_path=""): + self.base_path = base_path + + def __call__(self, data): + return os.path.join(self.base_path, data) + + +class LoadAudio(DataProcessingOperator): + def __init__(self, sr=16000): + self.sr = sr + def __call__(self, data: str): + import librosa + input_audio, sample_rate = librosa.load(data, sr=self.sr) + return input_audio diff --git a/diffsynth/core/data/unified_dataset.py b/diffsynth/core/data/unified_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..074208cf81a5c77bbd54ea08f4699ce63891916e --- /dev/null +++ b/diffsynth/core/data/unified_dataset.py @@ -0,0
+1,112 @@ +from .operators import * +import torch, json, pandas + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + special_operator_map=None, + ): + self.base_path = base_path + self.metadata_path = metadata_path + self.repeat = repeat + self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.cached_data_operator = LoadTorchPickle() + self.special_operator_map = {} if special_operator_map is None else special_operator_map + self.data = [] + self.cached_data = [] + self.load_from_cache = metadata_path is None + self.load_metadata(metadata_path) + + @staticmethod + def default_image_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))), + ]) + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), + ]) + + def search_for_cached_data_files(self, path): + for file_name in os.listdir(path): + subpath = os.path.join(path, file_name) + if os.path.isdir(subpath): + self.search_for_cached_data_files(subpath) + elif subpath.endswith(".pth"): + self.cached_data.append(subpath) + + def load_metadata(self, metadata_path): + if metadata_path is None: + print("No metadata_path. 
Searching for cached data files.") + self.search_for_cached_data_files(self.base_path) + print(f"{len(self.cached_data)} cached data files found.") + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in f: + metadata.append(json.loads(line.strip())) + self.data = metadata + else: + metadata = pandas.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + def __getitem__(self, data_id): + if self.load_from_cache: + data = self.cached_data[data_id % len(self.cached_data)] + data = self.cached_data_operator(data) + else: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + if key in self.special_operator_map: + data[key] = self.special_operator_map[key](data[key]) + elif key in self.data_file_keys: + data[key] = self.main_data_operator(data[key]) + return data + + def __len__(self): + if self.load_from_cache: + return len(self.cached_data) * self.repeat + else: + return len(self.data) * self.repeat + + def check_data_equal(self, data1, data2): + # Debug only + if len(data1) != len(data2): + return False + for k in data1: + if data1[k] != data2[k]: + return False + return True diff --git a/diffsynth/core/gradient/__init__.py b/diffsynth/core/gradient/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57914792a78ec32f69c3c99ae37535598efc8d52 --- /dev/null +++ b/diffsynth/core/gradient/__init__.py @@ -0,0 +1 @@ +from .gradient_checkpoint import gradient_checkpoint_forward diff --git a/diffsynth/core/gradient/gradient_checkpoint.py b/diffsynth/core/gradient/gradient_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b356415a004f3d74afdd45840f1fc4caf6659e16 --- /dev/null +++ b/diffsynth/core/gradient/gradient_checkpoint.py @@ -0,0 +1,34 @@ +import torch + + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output diff --git a/diffsynth/core/loader/__init__.py b/diffsynth/core/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f56d814bae40436f66bca583e33d180d6e11247 --- /dev/null +++ b/diffsynth/core/loader/__init__.py @@ -0,0 +1,3 @@ +from .file import load_state_dict, hash_state_dict_keys, hash_model_file +from .model import load_model, load_model_with_disk_offload +from .config import ModelConfig diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py new file mode 100644 index 0000000000000000000000000000000000000000..562675f3055cb4f4644f7bc4c524f3cbf6a54f76 --- /dev/null +++ b/diffsynth/core/loader/config.py @@ -0,0 +1,117 @@ +import torch, glob, os +from typing import Optional, Union +from dataclasses import dataclass +from modelscope 
import snapshot_download +from huggingface_hub import snapshot_download as hf_snapshot_download +from typing import Optional + + +@dataclass +class ModelConfig: + path: Union[str, list[str]] = None + model_id: str = None + origin_file_pattern: Union[str, list[str]] = None + download_source: str = None + local_model_path: str = None + skip_download: bool = None + offload_device: Optional[Union[str, torch.device]] = None + offload_dtype: Optional[torch.dtype] = None + onload_device: Optional[Union[str, torch.device]] = None + onload_dtype: Optional[torch.dtype] = None + preparing_device: Optional[Union[str, torch.device]] = None + preparing_dtype: Optional[torch.dtype] = None + computation_device: Optional[Union[str, torch.device]] = None + computation_dtype: Optional[torch.dtype] = None + clear_parameters: bool = False + + def check_input(self): + if self.path is None and self.model_id is None: + raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""") + + def parse_original_file_pattern(self): + if self.origin_file_pattern is None or self.origin_file_pattern == "": + return "*" + elif self.origin_file_pattern.endswith("/"): + return self.origin_file_pattern + "*" + else: + return self.origin_file_pattern + + def parse_download_source(self): + if self.download_source is None: + if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None: + return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') + else: + return "modelscope" + else: + return self.download_source + + def parse_skip_download(self): + if self.skip_download is None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true": + return True + elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false": + return False + else: + return False + else: + return self.skip_download + + def download(self): + origin_file_pattern = self.parse_original_file_pattern() + downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) + download_source = self.parse_download_source() + if download_source.lower() == "modelscope": + snapshot_download( + self.model_id, + local_dir=os.path.join(self.local_model_path, self.model_id), + allow_file_pattern=origin_file_pattern, + ignore_file_pattern=downloaded_files, + local_files_only=False + ) + elif download_source.lower() == "huggingface": + hf_snapshot_download( + self.model_id, + local_dir=os.path.join(self.local_model_path, self.model_id), + allow_patterns=origin_file_pattern, + ignore_patterns=downloaded_files, + local_files_only=False + ) + else: + raise ValueError("`download_source` should be `modelscope` or `huggingface`.") + + def require_downloading(self): + if self.path is not None: + return False + skip_download = self.parse_skip_download() + return not skip_download + + def reset_local_model_path(self): + if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None: + self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') + elif self.local_model_path is None: + self.local_model_path = "./models" + + def download_if_necessary(self): + self.check_input() + self.reset_local_model_path() + if self.require_downloading(): + self.download() + if self.origin_file_pattern is None or self.origin_file_pattern == "": + self.path = os.path.join(self.local_model_path, self.model_id) + else: + self.path = 
glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) + if isinstance(self.path, list) and len(self.path) == 1: + self.path = self.path[0] + + def vram_config(self): + return { + "offload_device": self.offload_device, + "offload_dtype": self.offload_dtype, + "onload_device": self.onload_device, + "onload_dtype": self.onload_dtype, + "preparing_device": self.preparing_device, + "preparing_dtype": self.preparing_dtype, + "computation_device": self.computation_device, + "computation_dtype": self.computation_dtype, + } diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66961f25d4fc547a2ec638f9d6a93be851afb9 --- /dev/null +++ b/diffsynth/core/loader/file.py @@ -0,0 +1,121 @@ +from safetensors import safe_open +import torch, hashlib + + +def load_state_dict(file_path, torch_dtype=None, device="cpu"): + if isinstance(file_path, list): + state_dict = {} + for file_path_ in file_path: + state_dict.update(load_state_dict(file_path_, torch_dtype, device)) + return state_dict + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): + state_dict = {} + with safe_open(file_path, framework="pt", device=str(device)) as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): + state_dict = torch.load(file_path, map_location=device, weights_only=True) + if len(state_dict) == 1: + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + elif "module" in state_dict: + state_dict = state_dict["module"] + elif "model_state" in state_dict: + state_dict = state_dict["model_state"] + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + + +def load_keys_dict(file_path): + if isinstance(file_path, list): + state_dict = {} + for file_path_ in file_path: + state_dict.update(load_keys_dict(file_path_)) + return state_dict + if file_path.endswith(".safetensors"): + return load_keys_dict_from_safetensors(file_path) + else: + return load_keys_dict_from_bin(file_path) + + +def load_keys_dict_from_safetensors(file_path): + keys_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + keys_dict[k] = f.get_slice(k).get_shape() + return keys_dict + + +def 
convert_state_dict_to_keys_dict(state_dict): + keys_dict = {} + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + keys_dict[k] = list(v.shape) + else: + keys_dict[k] = convert_state_dict_to_keys_dict(v) + return keys_dict + + +def load_keys_dict_from_bin(file_path): + state_dict = load_state_dict_from_bin(file_path) + keys_dict = convert_state_dict_to_keys_dict(state_dict) + return keys_dict + + +def convert_keys_dict_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, dict): + keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape)) + else: + if with_shape: + shape = "_".join(map(str, list(value))) + keys.append(key + ":" + shape) + keys.append(key) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_model_file(path, with_shape=True): + keys_dict = load_keys_dict(path) + keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py new file mode 100644 index 0000000000000000000000000000000000000000..56fa7d362e72770a9e5a1a09b3e7900f7860283a --- /dev/null +++ b/diffsynth/core/loader/model.py @@ -0,0 +1,79 @@ +from ..vram.initialization import skip_model_initialization +from ..vram.disk_map import DiskMap +from ..vram.layers import enable_vram_management +from .file import load_state_dict +import torch + + +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): + config = {} if config is None else config + # Why do we use `skip_model_initialization`? + # It skips the random initialization of model parameters, + # thereby speeding up model loading and avoiding excessive memory usage. + with skip_model_initialization(): + model = model_class(**config) + # What is `module_map`? + # This is a module mapping table for VRAM management. + if module_map is not None: + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] + device = [d for d in devices if d != "disk"][0] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] + dtype = [d for d in dtypes if d != "disk"][0] + if vram_config["offload_device"] != "disk": + state_dict = DiskMap(path, device, torch_dtype=dtype) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + model.load_state_dict(state_dict, assign=True) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) + else: + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) + else: + # Why do we use `DiskMap`? + # Sometimes a model file contains multiple models, + # and DiskMap can load only the parameters of a single model, + # avoiding the need to load all parameters in the file. 
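+ # Illustrative sketch of that lazy behavior (not executed here; the shard names and the parameter key are hypothetical): + # lazy_sd = DiskMap(["model-00001.safetensors", "model-00002.safetensors"], "cpu", torch_dtype=torch.bfloat16) + # tensor = lazy_sd["blocks.0.attn.q.weight"] # only this tensor is read from disk, not the whole file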
+ if use_disk_map: + state_dict = DiskMap(path, device, torch_dtype=torch_dtype) + else: + state_dict = load_state_dict(path, torch_dtype, device) + # Why do we use `state_dict_converter`? + # Some models are saved in complex formats, + # and we need to convert the state dict into the appropriate format. + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + model.load_state_dict(state_dict, assign=True) + # Why do we call `to()`? + # Because some models override the behavior of `to()`, + # especially those from libraries like Transformers. + model = model.to(dtype=torch_dtype, device=device) + if hasattr(model, "eval"): + model = model.eval() + return model + + +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): + if isinstance(path, str): + path = [path] + config = {} if config is None else config + with skip_model_initialization(): + model = model_class(**config) + if hasattr(model, "eval"): + model = model.eval() + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch_dtype, + "computation_device": device, + } + enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) + return model diff --git a/diffsynth/core/vram/__init__.py b/diffsynth/core/vram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32763bb9b4abfa4d5b2617827661c520d7e9fcae --- /dev/null +++ b/diffsynth/core/vram/__init__.py @@ -0,0 +1,2 @@ +from .initialization import skip_model_initialization +from .layers import * diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py new file mode 100644 index 0000000000000000000000000000000000000000..a666590fa99a9cc4de05dc3f5fa84c212e43de38 --- /dev/null +++ b/diffsynth/core/vram/disk_map.py @@ -0,0 +1,93 @@ +from safetensors import safe_open +import torch, os + + +class SafetensorsCompatibleTensor: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return list(self.tensor.shape) + + +class SafetensorsCompatibleBinaryLoader: + def __init__(self, path, device): + print("Detected non-safetensors files, which may cause slower loading. 
It's recommended to convert it to a safetensors file.") + self.state_dict = torch.load(path, weights_only=True, map_location=device) + + def keys(self): + return self.state_dict.keys() + + def get_tensor(self, name): + return self.state_dict[name] + + def get_slice(self, name): + return SafetensorsCompatibleTensor(self.state_dict[name]) + + +class DiskMap: + + def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): + self.path = path if isinstance(path, list) else [path] + self.device = device + self.torch_dtype = torch_dtype + if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None: + self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')) + else: + self.buffer_size = buffer_size + self.files = [] + self.flush_files() + self.name_map = {} + for file_id, file in enumerate(self.files): + for name in file.keys(): + self.name_map[name] = file_id + self.rename_dict = self.fetch_rename_dict(state_dict_converter) + + def flush_files(self): + if len(self.files) == 0: + for path in self.path: + if path.endswith(".safetensors"): + self.files.append(safe_open(path, framework="pt", device=str(self.device))) + else: + self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device)) + else: + for i, path in enumerate(self.path): + if path.endswith(".safetensors"): + self.files[i] = safe_open(path, framework="pt", device=str(self.device)) + self.num_params = 0 + + def __getitem__(self, name): + if self.rename_dict is not None: name = self.rename_dict[name] + file_id = self.name_map[name] + param = self.files[file_id].get_tensor(name) + if self.torch_dtype is not None and isinstance(param, torch.Tensor): + param = param.to(self.torch_dtype) + if isinstance(param, torch.Tensor) and param.device == "cpu": + param = param.clone() + if isinstance(param, torch.Tensor): + self.num_params += param.numel() + if self.num_params > self.buffer_size: + self.flush_files() + return param + + def fetch_rename_dict(self, state_dict_converter): + if state_dict_converter is None: + return None + state_dict = {} + for file in self.files: + for name in file.keys(): + state_dict[name] = name + state_dict = state_dict_converter(state_dict) + return state_dict + + def __iter__(self): + if self.rename_dict is not None: + return self.rename_dict.__iter__() + else: + return self.name_map.__iter__() + + def __contains__(self, x): + if self.rename_dict is not None: + return x in self.rename_dict + else: + return x in self.name_map diff --git a/diffsynth/core/vram/initialization.py b/diffsynth/core/vram/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2498b526638bfdd1c114c78aa0b98c251a47d --- /dev/null +++ b/diffsynth/core/vram/initialization.py @@ -0,0 +1,21 @@ +import torch +from contextlib import contextmanager + + +@contextmanager +def skip_model_initialization(device=torch.device("meta")): + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + old_register_parameter = torch.nn.Module.register_parameter + torch.nn.Module.register_parameter = register_empty_parameter + try: + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter diff --git a/diffsynth/core/vram/layers.py 
b/diffsynth/core/vram/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..852a8f33855f4906b26d4bbf177e9ea7b4725164 --- /dev/null +++ b/diffsynth/core/vram/layers.py @@ -0,0 +1,475 @@ +import torch, copy +from typing import Union +from .initialization import skip_model_initialization +from .disk_map import DiskMap + + +class AutoTorchModule(torch.nn.Module): + + def __init__( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + super().__init__() + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.state = 0 + self.name = "" + + def set_dtype_and_device( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + # Unspecified dtypes/devices fall back to the computation dtype/device. + self.offload_dtype = offload_dtype or computation_dtype + self.offload_device = offload_device or computation_device + self.onload_dtype = onload_dtype or computation_dtype + self.onload_device = onload_device or computation_device + self.preparing_dtype = preparing_dtype or computation_dtype + self.preparing_device = preparing_device or computation_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.vram_limit = vram_limit + + def cast_to(self, weight, dtype, device): + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def check_free_vram(self): + gpu_mem_state = torch.cuda.mem_get_info(self.computation_device) + used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3) + return used_memory < self.vram_limit + + def offload(self): + if self.state != 0: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state != 1: + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def param_name(self, name): + if self.name == "": + return name + else: + return self.name + "."
+ name + + +class AutoWrappedModule(AutoTorchModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.module = module + if offload_dtype == "disk": + self.name = name + self.disk_map = disk_map + self.required_params = [name for name, _ in self.module.named_parameters()] + self.disk_offload = True + else: + self.disk_offload = False + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True) + module.to(dtype=torch_dtype, device=device) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for buf in model.buffers(): + # If any parameters are registered as buffers (not in the state dict), + # we cannot offload this module to disk directly, so we recurse into its children instead. + for children in model.children(): + self.offload_to_disk(children) + break + else: + model.to("meta") + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.offload_to_disk(self.module) + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def cast_to(self, module, dtype, device): + return copy.deepcopy(module).to(dtype=dtype, device=device) + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + module = self.module + elif self.disk_offload and device == "disk": + module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True) + else: + module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device) + return module + + def forward(self, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() +
module = self.computation() + return module(*args, **kwargs) + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedNonRecurseModule(AutoWrappedModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + module, + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + name, + disk_map, + **kwargs + ) + if self.disk_offload: + self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)] + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True, strict=False) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for name in self.required_params: + getattr(self, name).to("meta") + + def cast_to(self, module, dtype, device): + # Parameter casting is implemented in the model architecture. 
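+ # A minimal sketch of the assumed contract (the block below is hypothetical, not defined in this file): + # the wrapped block casts its own parameters at call time, e.g. its forward may run + # torch.nn.functional.linear(x, self.weight.to(dtype=x.dtype, device=x.device)), + # which is why the module can be returned unchanged here.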
+ return module + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): + def __init__( + self, + module: torch.nn.Linear, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + with skip_model_initialization(): + super().__init__( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + ) + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.weight = module.weight + self.bias = module.bias + self.state = 0 + self.name = name + self.lora_A_weights = [] + self.lora_B_weights = [] + self.lora_merger = None + self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] + + if offload_dtype == "disk": + self.disk_map = disk_map + self.disk_offload = True + else: + self.disk_offload = False + + def fp8_linear( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ) -> torch.Tensor: + device = input.device + origin_dtype = input.dtype + origin_shape = input.shape + input = input.reshape(-1, origin_shape[-1]) + + x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values + fp8_max = 448.0 + # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn. + # To avoid overflow and ensure numerical compatibility during FP8 computation, + # we scale down the input by 2.0 in advance. + # This scaling will be compensated later during the final result scaling. 
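+ # Worked example with illustrative numbers: if a row has x_max = 896 and fp8_max = 448, then + # scale_a = clamp(896 / 448, min=1.0) = 2.0, the row is divided by 2.0 before the FP8 cast, and + # torch._scaled_mm multiplies the result by scale_a again, so the output comes back at the original scale.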
+ if self.computation_dtype == torch.float8_e4m3fnuz: + fp8_max = fp8_max / 2.0 + scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device) + scale_b = torch.ones((weight.shape[0], 1)).to(device=device) + input = input / (scale_a + 1e-8) + input = input.to(self.computation_dtype) + weight = weight.to(self.computation_dtype) + bias = bias.to(torch.bfloat16) + + result = torch._scaled_mm( + input, + weight.T, + scale_a=scale_a, + scale_b=scale_b.T, + bias=bias, + out_dtype=origin_dtype, + ) + new_shape = origin_shape[:-1] + result.shape[-1:] + result = result.reshape(new_shape) + return result + + def load_from_disk(self, torch_dtype, device, assign=True): + weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device) + bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device) + if assign: + state_dict = {"weight": weight} + if bias is not None: state_dict["bias"] = bias + self.load_state_dict(state_dict, assign=True) + return weight, bias + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.to("meta") + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.disk_offload and device == "disk": + weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False) + else: + weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device) + return weight, bias + + def linear_forward(self, x, weight, bias): + if self.enable_fp8: + out = self.fp8_linear(x, weight, bias) + else: + out = torch.nn.functional.linear(x, weight, bias) + return out + + def lora_forward(self, x, out): + if self.lora_merger is None: + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out = out + x @ lora_A.T @ lora_B.T + else: + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + lora_output.append(x @ lora_A.T @ lora_B.T) + lora_output = torch.stack(lora_output) + out = self.lora_merger(out, lora_output) + return out + + def forward(self, x, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() + weight, bias = self.computation() + out = self.linear_forward(x, weight, bias) + if 
len(self.lora_A_weights) > 0: + out = self.lora_forward(x, out) + return out + + +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs): + if isinstance(model, AutoWrappedNonRecurseModule): + model = model.module + for name, module in model.named_children(): + layer_name = name if name_prefix == "" else name_prefix + "." + name + for source_module, target_module in module_map.items(): + if isinstance(module, source_module): + module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs) + if isinstance(module_, AutoWrappedNonRecurseModule): + enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs) + setattr(model, name, module_) + break + else: + enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs) + + +def fill_vram_config(model, vram_config): + vram_config_ = vram_config.copy() + vram_config_["onload_dtype"] = vram_config["computation_dtype"] + vram_config_["onload_device"] = vram_config["computation_device"] + vram_config_["preparing_dtype"] = vram_config["computation_dtype"] + vram_config_["preparing_device"] = vram_config["computation_device"] + for k in vram_config: + if vram_config[k] != vram_config_[k]: + print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}") + break + return vram_config_ + + +def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs): + for source_module, target_module in module_map.items(): + # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly. + if isinstance(model, source_module): + vram_config = fill_vram_config(model, vram_config) + model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs) + break + else: + enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs) + # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled. 
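+ # Minimal usage sketch (the module_map below is an assumption, not the full mapping used elsewhere in the project): + # model = enable_vram_management(model, module_map={torch.nn.Linear: AutoWrappedLinear}, vram_config=model_config.vram_config()) + # if getattr(model, "vram_management_enabled", False): + # for module in model.modules(): + # if hasattr(module, "onload"): module.onload()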
+ model.vram_management_enabled = True + return model diff --git a/diffsynth/datasets/mvdataset.py b/diffsynth/datasets/mvdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..148a454571f671890e8e7fcc4032cc3949b8f4d5 --- /dev/null +++ b/diffsynth/datasets/mvdataset.py @@ -0,0 +1,393 @@ +import imageio, os, torch, warnings, torchvision, argparse, json +from peft import LoraConfig, inject_adapter_in_model +from PIL import Image +import pandas as pd +from tqdm import tqdm +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs +import random +from decord import VideoReader +from decord import cpu, gpu +import imageio.v3 as iio + +from torchvision import transforms +import torchvision +import random +import decord +from torchvision import transforms +import re +decord.bridge.set_bridge('torch') +import random +import numpy as np +from PIL import Image, ImageOps + +class MulltiShot_MultiView_Dataset(torch.utils.data.Dataset): + def __init__(self, dataset_base_path='/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/merged_mark_paishe_ds_meg_merge_dwposefilter_paishe.json', + ref_image_path='/root/paddlejob/workspace/qizipeng/code/longvideogen/output.json', + time_division_factor=4, + time_division_remainder=1, + max_pixels=1920*1080, + height_division_factor=16, width_division_factor=16, + transform=None, + length=None, + resolution=None, + prev_length=5, + ref_num = 3, + training = True): + self.data_path = dataset_base_path + self.data = [] + self.length = length + self.resolution = resolution + self.height, self.width = resolution + self.num_frames = length + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + self.max_pixels = max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.prev_length = prev_length + self.training = training + self.ref_num = ref_num + + with open(self.data_path, 'r') as f: + meta_datas = json.load(f) + + for video_path in tqdm(meta_datas.keys()): + context = meta_datas[video_path] + candidate_labels = list(context.keys()) + candidate_labels.remove('text') + + disk_path = meta_datas[video_path]["disk_path"] + if not disk_path.lower().endswith(".mp4"): + continue + + + # reader = imageio.get_reader(meta_datas[video_path]["disk_path"]) + # total_original_frames = reader.count_frames() + # total_frame = total_original_frames # context["end_index"] - context["start_index"] - 1 + total_frame = None + ref_id = self.get_ref_id(face_crop_angle = context['facedetect_v1'], facedetect_v1_frame_index = context['facedetect_v1_frame_index'], total_frame = total_frame) + if ref_id == []: + continue + ref_id_all = [] + for ids in ref_id: + ref_id_grop = [] + for id in ids: + coordinate = context['facedetect_v1'][id][0]['detect'] + if context['facedetect_v1'][id][0]['detect']["prob"] < 0.99: + continue + top, height, width, left = coordinate['top'], coordinate['height'], coordinate['width'], coordinate['left'] + if not(min(height, width) > 80 ): + continue + # optionally enlarge the bbox (currently kept at 1x) + width = int(width * 1) + height = int(height * 1) + frame_index = context['facedetect_v1_frame_index'][id] + ref_id_grop.append([top, height, width, left, int(frame_index)]) + if ref_id_grop != []: + if len(ref_id_grop) >= 3: #self.ref_num: ### keep the data consistent with ref_num = 3 + ref_id_all.append(ref_id_grop) + if ref_id_all == []: + continue + meta_prompt = {} + meta_prompt["global_caption"] = None + meta_prompt["per_shot_prompt"] = [] + meta_prompt["single_prompt"] = context['text'] + self.data.append({'video_path': disk_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all}) + # self.data.append({'video_path':video_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all}) + + random.seed(42) # keep the train/test split consistent across runs (optional) + total = len(self.data) + test_count = max(1, int(total * 0.05)) # at least one test sample + + # randomly pick the test indices + test_indices = set(random.sample(range(total), test_count)) + + self.data_test = [self.data[i] for i in range(total) if i in test_indices] + self.data_train = [self.data[i] for i in range(total) if i not in test_indices] + print(f"🔥 Dataset split done: Train={len(self.data_train)}, Test={len(self.data_test)}") + + if self.height is not None and self.width is not None: + print("Height and width are fixed. Setting `dynamic_resolution` to False.") + self.dynamic_resolution = False + elif self.height is None and self.width is None: + print("Height and width are none. Setting `dynamic_resolution` to True.") + self.dynamic_resolution = True + + def get_ref_id(self, face_crop_angle, facedetect_v1_frame_index = None, total_frame = None, angle_threshold=50): + """ + Return triplets [i, j, k] whose face angles differ enough. + Requirements: + - face_crop_angle[i] / [j] / [k] must all be non-empty + - at least one of the yaw/pitch/roll differences between i and j must exceed angle_threshold + - k != i != j, and k must also be non-empty + """ + ref_id = [] + max_try = 5 + need_max = 3 + try_num = 0 + + # filter out empty entries and keep the valid indices + valid_indices = [idx for idx, item in enumerate(face_crop_angle) if item] + N = len(valid_indices) + + if N < 3: + return ref_id # fewer than 3 valid faces, so no triplet can be formed + + # check the angle difference for every pair + for a in range(N - 1): + i = valid_indices[a] + # if facedetect_v1_frame_index[i] > total_frame: + # continue + angle_i = face_crop_angle[i][0]["angle"] + + for b in range(a + 1, N): + j = valid_indices[b] + # if facedetect_v1_frame_index[j] > total_frame: + # continue + angle_j = face_crop_angle[j][0]["angle"] + + # check whether the threshold is met + if ( + abs(angle_i["yaw"] - angle_j["yaw"]) > angle_threshold or + abs(angle_i["pitch"] - angle_j["pitch"]) > angle_threshold or + abs(angle_i["roll"] - angle_j["roll"]) > angle_threshold + ): + # find a third index k + for c in range(N): + k = valid_indices[c] + # if facedetect_v1_frame_index[k] > total_frame: + # continue + if k != i and k != j: + ref_id.append([i, j, k]) + break + + try_num += 1 + if try_num >= max_try or len(ref_id) >= need_max: + return ref_id + + return ref_id + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.dynamic_resolution: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + # def + # img_ratio = img.width / img.height + # target_ratio = w / h + # if img_ratio > target_ratio: # Image is wider than target + # new_width = w + # new_height = int(new_width / img_ratio) + # else: # Image is taller than target + # new_height = h + # new_width = int(new_height * img_ratio) + + # # img = img.resize((new_width, new_height), Image.ANTIALIAS) + # img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # # Create a new image with the target size and place the resized image in the center + # delta_w = w - img.size[0] + # delta_h = h - img.size[1] + # padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + # new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + def resize_ref(self, img, target_h, target_w): + h = target_h + w = target_w + img = img.convert("RGB") + # Calculate the required size to keep aspect ratio and fill the rest with padding. + img_ratio = img.width / img.height + target_ratio = w / h + + if img_ratio > target_ratio: # Image is wider than target + new_width = w + new_height = int(new_width / img_ratio) + else: # Image is taller than target + new_height = h + new_width = int(new_height * img_ratio) + + # img = img.resize((new_width, new_height), Image.ANTIALIAS) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Create a new image with the target size and place the resized image in the center + delta_w = w - img.size[0] + delta_h = h - img.size[1] + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + return new_img + + + def load_video_crop_ref_image(self, video_path=None, ref_id_all=[[]]): + ### fps conversion + reader = imageio.get_reader(video_path) + meta = reader.get_meta_data() + original_fps = meta.get("fps", 24) + target_fps = 16 + duration_seconds = 5 + target_frames = target_fps * duration_seconds + 1 # = 81 frames + + # ---- get the frame count of the original video ---- + try: + total_original_frames = reader.count_frames() + except: + total_original_frames = int(meta.get("duration", 5) * original_fps) + + + + # ---- how many original frames are needed (5 seconds) ---- + need_orig_frames = int(original_fps * duration_seconds) + + # ---- Case 1: original video >= 5 seconds → randomly choose a 5-second start point ---- + if total_original_frames > need_orig_frames: + max_start = total_original_frames - need_orig_frames + start_frame = random.randint(0, max_start) + segment_start = start_frame + segment_end = start_frame + need_orig_frames + else: + # ---- Case 2: original video < 5 seconds → use all frames ---- + segment_start = 0 + segment_end = total_original_frames + + # ---- uniformly sample target_frames (81) frames ---- + sample_ids = np.linspace(segment_start, segment_end - 1, num=target_frames, dtype=int) + + frames = [] + for frame_id in sample_ids: + frame = reader.get_data(int(frame_id)) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame, *self.get_height_width(frame)) + frames.append(frame) + + # =========================== + # Reference image selection + # =========================== + + # 1) randomly pick one group from ref_images_all (a nested list) + # ref_images_all = [ [img1, img2, img3], [imgA, imgB, imgC], ... ] + ref_group = random.choice(ref_id_all) + + # 2) check that the group has enough reference images + if len(ref_group) < self.ref_num: + raise ValueError(f"{self.ref_num} reference images are required, but this group only has {len(ref_group)}.") + + # 3) randomly select self.ref_num images from this group + selected_refs = random.sample(ref_group, self.ref_num) + random.shuffle(selected_refs) + + ref_images = [] + for sf in selected_refs: + top, height, width, left, frame_index = sf + # import pdb; pdb.set_trace() + if frame_index > total_original_frames: + print(f"{video_path}, frame_index({frame_index}) out of range") + frame = reader.get_data(int(frame_index)) + frame = Image.fromarray(frame) + xmin, ymin, xmax, ymax = left, top, left + width, top + height + cropped_image = frame.crop((xmin, ymin, xmax, ymax)).convert("RGB") + cropped_image = self.resize_ref(cropped_image, self.height, self.width) + # Calculate the required size to keep aspect ratio and fill the rest with padding. + ref_images.append(cropped_image) + reader.close() + + return frames, ref_images + + def __getitem__(self, index): + max_retry = 10 # retry at most 10 times to avoid an infinite loop + retry = 0 + + while retry < max_retry: + # ----- select train / test data ----- + if self.training: + meta_data = self.data_train[index % len(self.data_train)] + else: + meta_data = self.data_test[index % len(self.data_test)] + + video_path = meta_data['video_path'] + meta_prompt = meta_data['meta_prompt'] + ref_id_all = meta_data['ref_id_all'] + + # ----- try to load the video and reference images ----- + try: + input_video, ref_images = self.load_video_crop_ref_image( + video_path=video_path, + ref_id_all=ref_id_all + ) + except Exception as e: + print("❌ Exception in load_video_crop_ref_image") + print(f" video_path: {video_path}") + print(f" error type: {type(e).__name__}") + print(f" error msg : {e}") + + # print the traceback to make debugging easier + import traceback + traceback.print_exc() + input_video = None + ref_images = None + # ----- if loading succeeded and the video is not empty, return the sample ----- + if input_video is not None and len(input_video) > 0: + return { + "global_caption": None, + "shot_num": 1, + "pre_shot_caption": [], + "single_caption": meta_prompt["single_prompt"], + "video": input_video, + "ref_num": self.ref_num, + "ref_images": ref_images, + "video_path": video_path + } + + # ----- otherwise, pick another index and retry ----- + retry += 1 + index = random.randint(0, len(self.data_train) - 1 if self.training else len(self.data_test) - 1) + + # if all retries fail, raise an error + raise RuntimeError(f"❌ [Dataset] Failed to load video/ref after {max_retry} retries.") + + def __len__(self): + if self.training: + return len(self.data_train) + else: + return len(self.data_test) + +if __name__ == '__main__': + from torch.utils.data import DataLoader + dataset = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True) + print(len(dataset)) + metadata = dataset[0] + # results = dataset[0] + # loader = DataLoader( + # dataset, + # batch_size=1, # videos usually use batch_size=1 + # shuffle=False, # set to True if you want shuffling + # num_workers=10, # ⭐ key point: use multiple worker subprocesses for parallel loading + # pin_memory=True, + # prefetch_factor=2, # each worker prefetches 2 samples + # collate_fn=lambda x: x[0], # ⭐ do not collate at all + # ) + + # for batch in tqdm(loader): + # pass + for i in tqdm(range(len(dataset))): + file = dataset[i] + + assert 0 + \ No newline at end of file diff --git a/diffsynth/diffusion/__init__.py b/diffsynth/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4a0873a7b3d09e95aa00cfe340d653c58a834b --- /dev/null +++ b/diffsynth/diffusion/__init__.py @@ -0,0 +1,6 @@ +from .flow_match import FlowMatchScheduler +from .training_module import DiffusionTrainingModule +from
.logger import ModelLogger +from .runner import launch_training_task, launch_data_process_task +from .parsers import * +from .loss import * diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2bec693c5c57e0890613ba6544e453b5fc28c9e4 --- /dev/null +++ b/diffsynth/diffusion/base_pipeline.py @@ -0,0 +1,439 @@ +from PIL import Image +import torch +import numpy as np +from einops import repeat, reduce +from typing import Union +from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig +from ..utils.lora import GeneralLoRALoader +from ..models.model_loader import ModelPool +from ..utils.controlnet import ControlNetInput + + +class PipelineUnit: + def __init__( + self, + seperate_cfg: bool = False, + take_over: bool = False, + input_params: tuple[str] = None, + output_params: tuple[str] = None, + input_params_posi: dict[str, str] = None, + input_params_nega: dict[str, str] = None, + onload_model_names: tuple[str] = None + ): + self.seperate_cfg = seperate_cfg + self.take_over = take_over + self.input_params = input_params + self.output_params = output_params + self.input_params_posi = input_params_posi + self.input_params_nega = input_params_nega + self.onload_model_names = onload_model_names + + def fetch_input_params(self): + params = [] + if self.input_params is not None: + for param in self.input_params: + params.append(param) + if self.input_params_posi is not None: + for _, param in self.input_params_posi.items(): + params.append(param) + if self.input_params_nega is not None: + for _, param in self.input_params_nega.items(): + params.append(param) + params = sorted(list(set(params))) + return params + + def fetch_output_params(self): + params = [] + if self.output_params is not None: + for param in self.output_params: + params.append(param) + return params + + def process(self, pipe, **kwargs) -> dict: + return {} + + def post_process(self, pipe, **kwargs) -> dict: + return {} + + +class BasePipeline(torch.nn.Module): + + def __init__( + self, + device="cuda", torch_dtype=torch.float16, + height_division_factor=64, width_division_factor=64, + time_division_factor=None, time_division_remainder=None, + ): + super().__init__() + # The device and torch_dtype is used for the storage of intermediate variables, not models. + self.device = device + self.torch_dtype = torch_dtype + # The following parameters are used for shape check. + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # VRAM management + self.vram_management_enabled = False + # Pipeline Unit Runner + self.unit_runner = PipelineUnitRunner() + # LoRA Loader + self.lora_loader = GeneralLoRALoader + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def check_resize_height_width(self, height, width, num_frames=None): + # Shape check + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + print(f"height % {self.height_division_factor} != 0. 
We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") + if num_frames is None: + return height, width + else: + if num_frames % self.time_division_factor != self.time_division_remainder: + num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder + print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") + return height, width, num_frames + + + def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): + # Transform a PIL.Image to torch.Tensor + image = torch.Tensor(np.array(image, dtype=np.float32)) + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + image = image * ((max_value - min_value) / 255) + min_value + image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) + return image + + + def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a list of PIL.Image to torch.Tensor + video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] + video = torch.stack(video, dim=pattern.index("T") // 2) + return video + + + def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to PIL.Image + if pattern != "H W C": + vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") + image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) + image = image.to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(image.numpy()) + return image + + + def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to list of PIL.Image + if pattern != "T H W C": + vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") + video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] + return video + + + def load_models_to_device(self, model_names): + if self.vram_management_enabled: + # offload models + for name, model in self.named_children(): + if name not in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "offload"): + model.offload() + else: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + torch.cuda.empty_cache() + # onload models + for name, model in self.named_children(): + if name in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "onload"): + model.onload() + else: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + + + def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): + # Initialize Gaussian noise + generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) + noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + 
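+        # Note: the noise is sampled on rand_device (CPU by default) so a fixed seed gives the same noise regardless of the computation device, then cast and moved to the pipeline's dtype/device.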
return noise + + + def get_vram(self): + return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3) + + def get_module(self, model, name): + if "." in name: + name, suffix = name[:name.index(".")], name[name.index(".") + 1:] + if name.isdigit(): + return self.get_module(model[int(name)], suffix) + else: + return self.get_module(getattr(model, name), suffix) + else: + return getattr(model, name) + + def freeze_except(self, model_names): + self.eval() + self.requires_grad_(False) + for name in model_names: + module = self.get_module(self, name) + if module is None: + print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.") + continue + module.train() + module.requires_grad_(True) + + + def blend_with_mask(self, base, addition, mask): + return base * (1 - mask) + addition * mask + + + def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs): + timestep = scheduler.timesteps[progress_id] + if inpaint_mask is not None: + noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents) + noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask) + latents_next = scheduler.step(noise_pred, timestep, latents) + return latents_next + + + def split_pipeline_units(self, model_names: list[str]): + return PipelineUnitGraph().split_pipeline_units(self.units, model_names) + + + def flush_vram_management_device(self, device): + for module in self.modules(): + if isinstance(module, AutoTorchModule): + module.offload_device = device + module.onload_device = device + module.preparing_device = device + module.computation_device = device + + + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str] = None, + alpha=1, + hotload=None, + state_dict=None, + ): + if state_dict is None: + if isinstance(lora_config, str): + lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device) + else: + lora_config.download_if_necessary() + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + else: + lora = state_dict + lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + lora = lora_loader.convert_state_dict(lora) + if hotload is None: + hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled") + if hotload: + if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")): + raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.") + updated_num = 0 + for _, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + name = module.name + lora_a_name = f'{name}.lora_A.weight' + lora_b_name = f'{name}.lora_B.weight' + if lora_a_name in lora and lora_b_name in lora: + updated_num += 1 + module.lora_A_weights.append(lora[lora_a_name] * alpha) + module.lora_B_weights.append(lora[lora_b_name]) + print(f"{updated_num} tensors are patched by LoRA. 
You can use `pipe.clear_lora()` to clear all LoRA layers.") + else: + lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha) + + + def clear_lora(self): + cleared_num = 0 + for name, module in self.named_modules(): + if isinstance(module, AutoWrappedLinear): + if hasattr(module, "lora_A_weights"): + if len(module.lora_A_weights) > 0: + cleared_num += 1 + module.lora_A_weights.clear() + if hasattr(module, "lora_B_weights"): + module.lora_B_weights.clear() + print(f"{cleared_num} LoRA layers are cleared.") + + + def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None): + model_pool = ModelPool() + for model_config in model_configs: + model_config.download_if_necessary() + vram_config = model_config.vram_config() + vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype + vram_config["computation_device"] = vram_config["computation_device"] or self.device + model_pool.auto_load_model( + model_config.path, + vram_config=vram_config, + vram_limit=vram_limit, + clear_parameters=model_config.clear_parameters, + ) + return model_pool + + + def check_vram_management_state(self): + vram_management_enabled = False + for module in self.children(): + if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"): + vram_management_enabled = True + return vram_management_enabled + + + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + +class PipelineUnitGraph: + def __init__(self): + pass + + def build_edges(self, units: list[PipelineUnit]): + # Establish dependencies between units + # to search for subsequent related computation units. + last_compute_unit_id = {} + edges = [] + for unit_id, unit in enumerate(units): + for input_param in unit.fetch_input_params(): + if input_param in last_compute_unit_id: + edges.append((last_compute_unit_id[input_param], unit_id)) + for output_param in unit.fetch_output_params(): + last_compute_unit_id[output_param] = unit_id + return edges + + def build_chains(self, units: list[PipelineUnit]): + # Establish updating chains for each variable + # to track their computation process. + params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], []) + params = sorted(list(set(params))) + chains = {param: [] for param in params} + for unit_id, unit in enumerate(units): + for param in unit.fetch_output_params(): + chains[param].append(unit_id) + return chains + + def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]): + # Search for units that directly participate in the model's computation. + related_unit_ids = [] + for unit_id, unit in enumerate(units): + for model_name in model_names: + if unit.onload_model_names is not None and model_name in unit.onload_model_names: + related_unit_ids.append(unit_id) + break + return related_unit_ids + + def search_related_unit_ids(self, edges, start_unit_ids, direction="target"): + # Search for subsequent related computation units. 
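+        # Starting from start_unit_ids, repeatedly absorb graph neighbors (downstream units when direction="target", upstream units when direction="source") until the set stops growing.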
+ related_unit_ids = [unit_id for unit_id in start_unit_ids] + while True: + neighbors = [] + for source, target in edges: + if direction == "target" and source in related_unit_ids and target not in related_unit_ids: + neighbors.append(target) + elif direction == "source" and source not in related_unit_ids and target in related_unit_ids: + neighbors.append(source) + neighbors = sorted(list(set(neighbors))) + if len(neighbors) == 0: + break + else: + related_unit_ids.extend(neighbors) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids): + # If the input parameters of this subgraph are updated outside the subgraph, + # search for the units where these updates occur. + first_compute_unit_id = {} + for unit_id in related_unit_ids: + for param in units[unit_id].fetch_input_params(): + if param not in first_compute_unit_id: + first_compute_unit_id[param] = unit_id + updating_unit_ids = [] + for param in first_compute_unit_id: + unit_id = first_compute_unit_id[param] + chain = chains[param] + if unit_id in chain and chain.index(unit_id) != len(chain) - 1: + for unit_id_ in chain[chain.index(unit_id) + 1:]: + if unit_id_ not in related_unit_ids: + updating_unit_ids.append(unit_id_) + related_unit_ids.extend(updating_unit_ids) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]): + # Split the computation graph, + # separating all model-related computations. + related_unit_ids = self.search_direct_unit_ids(units, model_names) + edges = self.build_edges(units) + chains = self.build_chains(units) + while True: + num_related_unit_ids = len(related_unit_ids) + related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target") + related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids) + if len(related_unit_ids) == num_related_unit_ids: + break + else: + num_related_unit_ids = len(related_unit_ids) + related_units = [units[i] for i in related_unit_ids] + unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids] + return related_units, unrelated_units + + +class PipelineUnitRunner: + def __init__(self): + pass + + def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: + # Let the pipeline unit take over this function. 
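+            # The unit's process() receives the full shared / positive / negative input dicts and returns updated versions of all three.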
+ inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) + elif unit.seperate_cfg: + # Positive side + processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_posi.update(processor_outputs) + # Negative side + if inputs_shared["cfg_scale"] != 1: + processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_nega.update(processor_outputs) + else: + inputs_nega.update(processor_outputs) + else: + processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_shared.update(processor_outputs) + return inputs_shared, inputs_posi, inputs_nega diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5fbc52855d730354daa7f92d1a0ed3715f670b --- /dev/null +++ b/diffsynth/diffusion/flow_match.py @@ -0,0 +1,179 @@ +import torch, math +from typing_extensions import Literal + + +class FlowMatchScheduler(): + + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"): + self.set_timesteps_fn = { + "FLUX.1": FlowMatchScheduler.set_timesteps_flux, + "Wan": FlowMatchScheduler.set_timesteps_wan, + "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, + "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, + "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + }.get(template, FlowMatchScheduler.set_timesteps_flux) + self.num_train_timesteps = 1000 + + @staticmethod + def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.003/1.002 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 5 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + @staticmethod + def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + shift_terminal = 0.02 + # Sigmas + sigma_start = sigma_min + 
(sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1 - sigmas + scale_factor = one_minus_z[-1] / (1 - shift_terminal) + sigmas = 1 - (one_minus_z / scale_factor) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def compute_empirical_mu(image_seq_len, num_steps): + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + @staticmethod + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16): + sigma_min = 1 / num_inference_steps + sigma_max = 1.0 + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + if target_timesteps is not None: + target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device) + for timestep in target_timesteps: + timestep_id = torch.argmin((timesteps - timestep).abs()) + timesteps[timestep_id] = timestep + return sigmas, timesteps + + def set_training_weight(self): + steps = 1000 + x = self.timesteps + y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (steps / y_shifted.sum()) + if len(self.timesteps) != 1000: + # This is an empirical formula. 
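+            # Rescale so the weights keep a comparable overall magnitude when fewer than 1000 timesteps are scheduled.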
+ bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) + bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] + self.linear_timesteps_weights = bsmntw_weighing + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): + self.sigmas, self.timesteps = self.set_timesteps_fn( + num_inference_steps=num_inference_steps, + denoising_strength=denoising_strength, + **kwargs, + ) + if training: + self.set_training_weight() + self.training = True + else: + self.training = False + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise(self, original_samples, noise, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + weights = self.linear_timesteps_weights[timestep_id] + return weights diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ff51e2c0ad0a28f0bbbb8ce12603617dc769afc2 --- /dev/null +++ b/diffsynth/diffusion/logger.py @@ -0,0 +1,43 @@ +import os, torch +from accelerate import Accelerator + + +class ModelLogger: + def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x): + self.output_path = output_path + self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + self.state_dict_converter = state_dict_converter + self.num_steps = 0 + + + def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + 
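+    # Gather the full state dict on the main process, keep only the trainable parameters, and write a safetensors checkpoint.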
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) + accelerator.save(state_dict, path, safe_serialization=True) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ae44bb6831eac932933fef812c4ac1b31decebbe --- /dev/null +++ b/diffsynth/diffusion/loss.py @@ -0,0 +1,119 @@ +from .base_pipeline import BasePipeline +import torch + + +def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) + + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) + + noise = torch.randn_like(inputs["input_latents"]) + inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * pipe.scheduler.training_weight(timestep) + return loss + + +def DirectDistillLoss(pipe: BasePipeline, **inputs): + pipe.scheduler.set_timesteps(inputs["num_inference_steps"]) + pipe.scheduler.training = True + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id) + inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs) + loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float()) + return loss + + +class TrajectoryImitationLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.initialized = False + + def initialize(self, device): + import lpips # TODO: remove it + self.loss_fn = lpips.LPIPS(net='alex').to(device) + self.initialized = True + + def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + trajectory = [inputs_shared["latents"].clone()] + + pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, 
noise_pred=noise_pred.detach(), **inputs_shared) + + trajectory.append(inputs_shared["latents"].clone()) + return pipe.scheduler.timesteps, trajectory + + def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + loss = 0 + pipe.scheduler.set_timesteps(num_inference_steps, training=True) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + + progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs()) + inputs_shared["latents"] = trajectory_teacher[progress_id_teacher] + + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + + sigma = pipe.scheduler.sigmas[progress_id] + sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1] + if progress_id + 1 >= len(pipe.scheduler.timesteps): + latents_ = trajectory_teacher[-1] + else: + progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) + latents_ = trajectory_teacher[progress_id_teacher] + + target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma) + loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) + return loss + + def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + inputs_shared["latents"] = trajectory_teacher[0] + pipe.scheduler.set_timesteps(num_inference_steps) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared) + + image_pred = pipe.vae_decoder(inputs_shared["latents"]) + image_real = pipe.vae_decoder(trajectory_teacher[-1]) + loss = self.loss_fn(image_pred.float(), image_real.float()) + return loss + + def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega): + if not self.initialized: + self.initialize(pipe.device) + with torch.no_grad(): + pipe.scheduler.set_timesteps(8) + timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2) + timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device) + loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss = loss_1 + loss_2 + return loss diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c6c6afdd3868aab45b1b51a3a47fe2e37d77f4 --- /dev/null +++ b/diffsynth/diffusion/parsers.py @@ -0,0 +1,70 @@ +import argparse + + +def 
add_dataset_base_config(parser: argparse.ArgumentParser): + parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.") + return parser + +def add_image_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + return parser + +def add_video_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") + return parser + +def add_model_config(parser: argparse.ArgumentParser): + parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") + parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.") + parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. 
Only used in split training.") + return parser + +def add_training_config(parser: argparse.ArgumentParser): + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") + parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") + return parser + +def add_output_config(parser: argparse.ArgumentParser): + parser.add_argument("--output_path", type=str, default="./models", help="Output save path.") + parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.") + parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.") + return parser + +def add_lora_config(parser: argparse.ArgumentParser): + parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") + parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") + parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") + parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. 
If provided, this LoRA will be fused to the base model.") + parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.") + return parser + +def add_gradient_config(parser: argparse.ArgumentParser): + parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + return parser + +def add_general_config(parser: argparse.ArgumentParser): + parser = add_dataset_base_config(parser) + parser = add_model_config(parser) + parser = add_training_config(parser) + parser = add_output_config(parser) + parser = add_lora_config(parser) + parser = add_gradient_config(parser) + return parser diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd4c23e188eed01486b740fee032a3fde3d50e7 --- /dev/null +++ b/diffsynth/diffusion/runner.py @@ -0,0 +1,129 @@ +import os, torch +import wandb # newly added for experiment logging +from tqdm import tqdm +from accelerate import Accelerator +from .training_module import DiffusionTrainingModule +from .logger import ModelLogger + + +def launch_training_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, + num_workers: int = 1, + save_steps: int = None, + num_epochs: int = 1, + args = None, +): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + collate_fn=lambda x: x[0], + num_workers=num_workers, + ) + + model, optimizer, dataloader, scheduler = accelerator.prepare( + model, optimizer, dataloader, scheduler + ) + + global_step = 0 # global step counter used for wandb logging + + for epoch_id in range(num_epochs): + # Show tqdm only on the local main process to avoid duplicate progress bars on multi-GPU runs + pbar = tqdm( + dataloader, + disable=not accelerator.is_local_main_process, + desc=f"Epoch {epoch_id}", + ) + for data in pbar: + with accelerator.accumulate(model): + optimizer.zero_grad() + if dataset.load_from_cache: + loss = model({}, inputs=data) + else: + loss = model(data) + accelerator.backward(loss) + optimizer.step() + model_logger.on_step_end(accelerator, model, save_steps) + scheduler.step() + + global_step += 1 + + # ============= wandb logging (main process only) ============= + if ( + args is not None + and hasattr(args, "wandb_mode") + and args.wandb_mode != "disabled" + and accelerator.is_main_process + ): + log_every = getattr(args, "wandb_log_every", 10) + if global_step % log_every == 0: + # Logging the loss of the current process is sufficient here + loss_value = loss.detach().float().item() + try: + lr = scheduler.get_last_lr()[0] + except Exception: + lr = optimizer.param_groups[0]["lr"] + + wandb.log( + { + "train/loss": loss_value, + "train/lr": lr, + "train/epoch": epoch_id, + "train/step": global_step, + } + ) + # ======================================================= + + if save_steps is None: + 
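+            # No step-based checkpointing configured: save once at the end of every epoch instead.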
model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) + + +def launch_data_process_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + num_workers: int = 8, + args = None, +): + if args is not None: + num_workers = args.dataset_num_workers + + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=num_workers, + ) + model, dataloader = accelerator.prepare(model, dataloader) + + for data_id, data in enumerate(tqdm( + dataloader, + disable=not accelerator.is_local_main_process, + desc="Data process", + )): + with accelerator.accumulate(model): + with torch.no_grad(): + folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) + os.makedirs(folder, exist_ok=True) + save_path = os.path.join( + model_logger.output_path, + str(accelerator.process_index), + f"{data_id}.pth", + ) + data = model(data) + torch.save(data, save_path) diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b33291e873aeb340c79576fb07454ec4a502e6 --- /dev/null +++ b/diffsynth/diffusion/training_module.py @@ -0,0 +1,212 @@ +import torch, json +from ..core import ModelConfig, load_state_dict +from ..utils.controlnet import ControlNetInput +from peft import LoraConfig, inject_adapter_in_model + + +class DiffusionTrainingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + + def to(self, *args, **kwargs): + for name, model in self.named_children(): + model.to(*args, **kwargs) + return self + + + def trainable_modules(self): + trainable_modules = filter(lambda p: p.requires_grad, self.parameters()) + return trainable_modules + + + def trainable_param_names(self): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + return trainable_param_names + + + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None): + if lora_alpha is None: + lora_alpha = lora_rank + if isinstance(target_modules, list) and len(target_modules) == 1: + target_modules = target_modules[0] + lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) + model = inject_adapter_in_model(lora_config, model) + if upcast_dtype is not None: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(upcast_dtype) + return model + + + def mapping_lora_state_dict(self, state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if "lora_A.weight" in key or "lora_B.weight" in key: + new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight") + new_state_dict[new_key] = value + elif "lora_A.default.weight" in key or "lora_B.default.weight" in key: + new_state_dict[key] = value + return new_state_dict + + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + trainable_param_names = self.trainable_param_names() + state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} + if remove_prefix is not None: + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith(remove_prefix): + name = name[len(remove_prefix):] + 
state_dict_[name] = param + state_dict = state_dict_ + return state_dict + + + def transfer_data_to_device(self, data, device, torch_float_dtype=None): + if data is None: + return data + elif isinstance(data, torch.Tensor): + data = data.to(device) + if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]: + data = data.to(torch_float_dtype) + return data + elif isinstance(data, tuple): + data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, list): + data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, dict): + data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data} + return data + else: + return data + + def parse_vram_config(self, fp8=False, offload=False, device="cpu"): + if fp8: + return { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": device, + "onload_dtype": torch.float8_e4m3fn, + "onload_device": device, + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + } + elif offload: + return { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + "clear_parameters": True, + } + else: + return {} + + def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"): + fp8_models = [] if fp8_models is None else fp8_models.split(",") + offload_models = [] if offload_models is None else offload_models.split(",") + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + for path in model_paths: + vram_config = self.parse_vram_config( + fp8=path in fp8_models, + offload=path in offload_models, + device=device + ) + model_configs.append(ModelConfig(path=path, **vram_config)) + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + for model_id_with_origin_path in model_id_with_origin_paths: + model_id, origin_file_pattern = model_id_with_origin_path.split(":") + vram_config = self.parse_vram_config( + fp8=model_id_with_origin_path in fp8_models, + offload=model_id_with_origin_path in offload_models, + device=device + ) + model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config)) + return model_configs + + + def switch_pipe_to_training_mode( + self, + pipe, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + task="sft", + ): + # Scheduler + pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Preset LoRA + if preset_lora_path is not None: + pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path) + + # FP8 + # FP8 relies on a model-specific memory management scheme. + # It is delegated to the subclass. + + # Add LoRA to the base models + if lora_base_model is not None and not task.endswith(":data_process"): + if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None: + print(f"No {lora_base_model} models in the pipeline. 
We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.") + return + model = self.add_lora_to_model( + getattr(pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank, + upcast_dtype=pipe.torch_dtype, + ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") + setattr(pipe, lora_base_model, model) + + + def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None): + models_require_backward = [] + if trainable_models is not None: + models_require_backward += trainable_models.split(",") + if lora_base_model is not None: + models_require_backward += [lora_base_model] + if task.endswith(":data_process"): + _, pipe.units = pipe.split_pipeline_units(models_require_backward) + elif task.endswith(":train"): + pipe.units, _ = pipe.split_pipeline_units(models_require_backward) + return pipe + + def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + controlnet_keys_map = ( + ("blockwise_controlnet_", "blockwise_controlnet_inputs",), + ("controlnet_", "controlnet_inputs"), + ) + controlnet_inputs = {} + for extra_input in extra_inputs: + for prefix, name in controlnet_keys_map: + if extra_input.startswith(prefix): + if name not in controlnet_inputs: + controlnet_inputs[name] = {} + controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input] + break + else: + inputs_shared[extra_input] = data[extra_input] + for name, params in controlnet_inputs.items(): + inputs_shared[name] = [ControlNetInput(**params)] + return inputs_shared diff --git a/diffsynth/models/comp_attn_model.py b/diffsynth/models/comp_attn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..86baabb0eea81af1142ac2b9566962dacb37c389 --- /dev/null +++ b/diffsynth/models/comp_attn_model.py @@ -0,0 +1,592 @@ +import math +from dataclasses import dataclass +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F + +from ..diffusion.base_pipeline import PipelineUnit + + +@dataclass +class CompAttnConfig: + subjects: Sequence[str] + bboxes: Optional[Sequence] = None + enable_sci: bool = True + enable_lam: bool = True + temperature: float = 0.2 + apply_to_negative: bool = False + interpolate: bool = False + state_texts: Optional[Sequence[Sequence[str]]] = None + state_weights: Optional[Sequence] = None + state_scale: float = 1.0 + state_template: str = "{subject} is {state}" + + +def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]: + if subject_ids.numel() == 0 or valid_len <= 0: + return [] + prompt_slice = prompt_ids[:valid_len].tolist() + subject_list = subject_ids.tolist() + span = len(subject_list) + if span > valid_len: + return [] + for start in range(valid_len - span + 1): + if prompt_slice[start:start + span] == subject_list: + return list(range(start, start + span)) + return [] + + +def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor: + mask = torch.zeros((len(indices_list), seq_len), dtype=torch.bool) + for i, indices in enumerate(indices_list): + if not indices: + continue + mask[i, 
torch.tensor(indices, dtype=torch.long)] = True + return mask + + +def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor: + prompt_norm = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + 1e-8) + anchor_norm = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + 1e-8) + cosine = torch.matmul(prompt_norm, anchor_norm.transpose(0, 1)) + scores = torch.exp(cosine / tau) + diag = scores.diagonal() + denom = scores.sum(dim=1).clamp(min=1e-8) + return diag / denom + + +def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor: + total = anchor_vecs.sum(dim=0, keepdim=True) + return anchor_vecs * anchor_vecs.shape[0] - total + + +_sci_call_count = [0] # use a list so the counter can be mutated inside the function + +def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor: + if state is None or not state.get("enable_sci", False): + return context + subject_mask = state.get("subject_token_mask") + delta = state.get("delta") + saliency = state.get("saliency") + if subject_mask is None or delta is None or saliency is None: + return context + if subject_mask.numel() == 0: + return context + t_scale = float(state.get("timestep_scale", 1000.0)) + t_value = float(timestep.reshape(-1)[0].item()) + t_ratio = max(0.0, min(1.0, t_value / t_scale)) + omega = 1.0 - t_ratio + delta = delta.to(device=context.device, dtype=context.dtype) + saliency = saliency.to(device=context.device, dtype=context.dtype) + scale = omega * (1.0 - saliency).unsqueeze(-1) + delta = delta * scale + mask = subject_mask.to(device=context.device) + token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta) + apply_mask = state.get("apply_mask") + if apply_mask is not None: + apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1) + else: + apply_mask = 1.0 + + # ========== DEBUG: print SCI info ========== + _sci_call_count[0] += 1 + if _sci_call_count[0] % 100 == 1: + print(f"\n{'='*60}") + print(f"[SCI (Saliency-Controlled Intervention) #{_sci_call_count[0]}]") + print(f" timestep: {t_value:.2f}, t_ratio: {t_ratio:.4f}, omega: {omega:.4f}") + print(f" saliency per subject: {saliency.tolist()}") + print(f" delta shape: {delta.shape}") + print(f" delta norm per subject: {delta.norm(dim=-1).tolist()}") + print(f" token_delta shape: {token_delta.shape}") + print(f" context modification norm: {(token_delta.unsqueeze(0) * apply_mask).norm().item():.6f}") + print(f"{'='*60}\n") + + return context + token_delta.unsqueeze(0) * apply_mask + + +def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor: + if bboxes.shape[2] == target_frames: + return bboxes + b, m, f, _ = bboxes.shape + coords = bboxes.reshape(b * m, f, 4).transpose(1, 2) + coords = F.interpolate(coords, size=target_frames, mode="linear", align_corners=True) + coords = coords.transpose(1, 2).reshape(b, m, target_frames, 4) + return coords + + +def build_layout_mask_from_bboxes( + bboxes: torch.Tensor, + grid_size: tuple[int, int, int], + image_size: tuple[int, int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + if bboxes is None: + return None + bboxes = bboxes.to(device=device, dtype=dtype) + b, m, f_layout, _ = bboxes.shape + f_grid, h_grid, w_grid = grid_size + height, width = image_size + layout = torch.zeros((b, m, f_grid, h_grid, w_grid), device=device, dtype=dtype) + for bi in range(b): + for mi in range(m): + for ti in range(f_layout): + pt = int(ti * f_grid / max(1, f_layout)) + pt = max(0, min(f_grid - 1, pt)) + x0, y0, x1, y1 = 
bboxes[bi, mi, ti] + x0 = float(x0) + y0 = float(y0) + x1 = float(x1) + y1 = float(y1) + if x1 <= x0 or y1 <= y0: + continue + px0 = int(math.floor(x0 / max(1.0, width) * w_grid)) + px1 = int(math.ceil(x1 / max(1.0, width) * w_grid)) + py0 = int(math.floor(y0 / max(1.0, height) * h_grid)) + py1 = int(math.ceil(y1 / max(1.0, height) * h_grid)) + px0 = max(0, min(w_grid, px0)) + px1 = max(0, min(w_grid, px1)) + py0 = max(0, min(h_grid, py0)) + py1 = max(0, min(h_grid, py1)) + if px1 <= px0 or py1 <= py0: + continue + layout[bi, mi, pt, py0:py1, px0:px1] = 1.0 + return layout.flatten(2) + + +_lam_attention_call_count = [0] # use a list so the counter can be mutated inside the function + +def lam_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_heads: int, + state: dict, +) -> Optional[torch.Tensor]: + subject_mask = state.get("subject_token_mask_lam") + if subject_mask is None: + subject_mask = state.get("subject_token_mask") + layout_mask = state.get("layout_mask") + state_token_mask = state.get("state_token_mask") + state_token_weights = state.get("state_token_weights") + state_scale = float(state.get("state_scale", 1.0)) + grid_shape = state.get("grid_shape") + enable_lam = bool(state.get("enable_lam", False)) + enable_state = state_token_mask is not None and state_token_weights is not None and grid_shape is not None + if not enable_lam and not enable_state: + return None + b, q_len, dim = q.shape + _, k_len, _ = k.shape + if enable_lam: + if subject_mask is None or layout_mask is None: + return None + if subject_mask.numel() == 0 or layout_mask.numel() == 0: + return None + if layout_mask.shape[-1] != q_len: + return None + if subject_mask.shape[-1] != k_len: + return None + if enable_state: + if state_token_mask.shape[-1] != k_len: + return None + head_dim = dim // num_heads + qh = q.view(b, q_len, num_heads, head_dim).transpose(1, 2) + kh = k.view(b, k_len, num_heads, head_dim).transpose(1, 2) + vh = v.view(b, k_len, num_heads, head_dim).transpose(1, 2) + attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim) + + # ========== DEBUG: print attention map info ========== + _lam_attention_call_count[0] += 1 + call_id = _lam_attention_call_count[0] + # print once every 100 calls to avoid flooding the output + if call_id % 100 == 1: + print(f"\n{'='*60}") + print(f"[LAM Attention #{call_id}]") + print(f" Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}") + print(f" num_heads: {num_heads}, head_dim: {head_dim}") + print(f" attn_scores shape: {attn_scores.shape}") + print(f" attn_scores stats: min={attn_scores.min().item():.4f}, max={attn_scores.max().item():.4f}, mean={attn_scores.mean().item():.4f}") + if enable_lam and layout_mask is not None: + print(f" layout_mask shape: {layout_mask.shape}") + print(f" layout_mask sum per subject: {layout_mask.sum(dim=-1)}") + if subject_mask is not None: + print(f" subject_token_mask shape: {subject_mask.shape}") + print(f" subject_token_mask active tokens per subject: {subject_mask.sum(dim=-1).tolist()}") + if grid_shape is not None: + print(f" grid_shape (f, h, w): {grid_shape}") + print(f"{'='*60}") + bias = torch.zeros_like(attn_scores) + if enable_lam: + attn_max = attn_scores.max(dim=-1, keepdim=True).values + attn_min = attn_scores.min(dim=-1, keepdim=True).values + g_plus = attn_max - attn_scores + g_minus = attn_min - attn_scores + subject_mask = subject_mask.to(device=attn_scores.device) + layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype) + apply_mask = state.get("apply_mask") + if apply_mask is not None: + layout_mask = layout_mask 
* apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1) + subject_any = subject_mask.any(dim=0) + for k_idx in range(subject_mask.shape[0]): + mask_k = subject_mask[k_idx] + if not mask_k.any(): + continue + mask_other = subject_any & (~mask_k) + mask_k = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len) + mask_other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len) + g_k = g_plus * mask_k + g_minus * mask_other + attn_k = attn_scores[..., subject_mask[k_idx]].mean(dim=-1).mean(dim=1) + adapt_mask = attn_k >= attn_k.mean(dim=-1, keepdim=True) + layout_k = layout_mask[:, k_idx] + adapt_f = adapt_mask.to(layout_k.dtype) + inter = (adapt_f * layout_k).sum(dim=-1) + union = (adapt_f + layout_k - adapt_f * layout_k).sum(dim=-1) + iou = inter / union.clamp(min=1e-6) + strength = (1.0 - iou).view(b, 1, 1, 1) + bias = bias + g_k * strength * layout_k.view(b, 1, q_len, 1) + if enable_state: + f, h, w = grid_shape + if f * h * w != q_len: + return None + state_token_mask = state_token_mask.to(device=attn_scores.device) + state_indices = torch.nonzero(state_token_mask, as_tuple=False).flatten() + if state_indices.numel() == 0: + return None + weights = state_token_weights.to(device=attn_scores.device, dtype=attn_scores.dtype) + if weights.shape[1] != f: + return None + time_index = torch.arange(q_len, device=attn_scores.device) // (h * w) + weights_q = weights[:, time_index, :] + if weights_q.shape[-1] != state_indices.numel(): + return None + state_bias = torch.zeros((b, 1, q_len, k_len), device=attn_scores.device, dtype=attn_scores.dtype) + state_bias[:, :, :, state_indices] = weights_q.unsqueeze(1) * state_scale + bias = bias + state_bias + attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype) + + # ========== DEBUG: print attention probs and bias info ========== + if _lam_attention_call_count[0] % 100 == 1: + print(f"\n[LAM Attention #{_lam_attention_call_count[0]} - After Bias]") + print(f" bias shape: {bias.shape}") + print(f" bias stats: min={bias.min().item():.4f}, max={bias.max().item():.4f}, mean={bias.mean().item():.4f}") + print(f" bias non-zero ratio: {(bias != 0).float().mean().item():.4f}") + print(f" attn_probs shape: {attn_probs.shape}") + print(f" attn_probs stats: min={attn_probs.min().item():.6f}, max={attn_probs.max().item():.6f}") + # print the average attention weight of the tokens belonging to each subject + if subject_mask is not None: + for subj_idx in range(subject_mask.shape[0]): + mask_k = subject_mask[subj_idx] + if mask_k.any(): + # average attention from all queries to this subject's tokens + subj_attn = attn_probs[:, :, :, mask_k.to(attn_probs.device)].mean() + print(f" Subject {subj_idx} avg attention weight: {subj_attn.item():.6f}") + print(f"{'='*60}\n") + + out = torch.matmul(attn_probs, vh) + out = out.transpose(1, 2).reshape(b, q_len, dim) + return out + + +class CompAttnUnit(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "context": "context"}, + input_params_nega={"prompt": "negative_prompt", "context": "context"}, + output_params=("comp_attn_state",), + onload_model_names=("text_encoder",), + ) + + def _clean_text(self, pipe, text: str) -> str: + if getattr(pipe.tokenizer, "clean", None): + return pipe.tokenizer._clean(text) + return text + + def _tokenize_subject(self, pipe, text: str) -> torch.Tensor: + text = self._clean_text(pipe, text) + tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt") + return tokens["input_ids"][0] + + def 
_normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32) + if bboxes.dim() == 2 and bboxes.shape[-1] == 4: + bboxes = bboxes.unsqueeze(0).unsqueeze(0) + elif bboxes.dim() == 3 and bboxes.shape[-1] == 4: + bboxes = bboxes.unsqueeze(0) + elif bboxes.dim() != 4 or bboxes.shape[-1] != 4: + raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}") + return bboxes + + def process(self, pipe, prompt, context) -> dict: + config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None) + if context is None or prompt is None or config is None: + return {} + if not config.subjects: + return {} + negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None) + if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt: + return {} + pipe.load_models_to_device(self.onload_model_names) + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + prompt_ids = ids[0] + valid_len = int(mask[0].sum().item()) + indices_list = [] + valid_subjects = [] + for idx, subject in enumerate(config.subjects): + subject_ids = self._tokenize_subject(pipe, subject) + indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len) + if not indices: + print(f"Comp-Attn: subject tokens not found in prompt: {subject}") + continue + indices_list.append(indices) + valid_subjects.append(idx) + if not indices_list: + return {} + subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device) + mask_float = subject_token_mask.to(dtype=context.dtype) + denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1) + prompt_vecs = (mask_float @ context[0]) / denom + anchor_vecs = [] + for idx in valid_subjects: + subject = config.subjects[idx] + sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True) + sub_ids = sub_ids.to(pipe.device) + sub_mask = sub_mask.to(pipe.device) + emb = pipe.text_encoder(sub_ids, sub_mask) + pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1) + anchor_vecs.append(pooled) + anchor_vecs = torch.cat(anchor_vecs, dim=0) + saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype) + delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype)) + bboxes = None + state_vectors = None + state_weights = None + state_len = 0 + if config.bboxes is not None: + bboxes = self._normalize_bboxes(config.bboxes) + if bboxes.shape[1] >= len(config.subjects): + bboxes = bboxes[:, valid_subjects] + if bboxes.shape[1] != len(valid_subjects): + print("Comp-Attn: bboxes subject count mismatch, disable LAM") + bboxes = None + if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None: + bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames)) + if config.state_texts is not None and config.state_weights is not None: + state_texts = config.state_texts + if len(valid_subjects) != len(config.subjects): + subject_names = [config.subjects[i] for i in valid_subjects] + state_texts = [state_texts[i] for i in valid_subjects] + else: + subject_names = list(config.subjects) + if len(state_texts) != len(subject_names): + raise ValueError("state_texts must align with subjects") + state_count = len(state_texts[0]) + for row in state_texts: + if len(row) != state_count: + raise ValueError("state_texts must have the same number of states per subject") + 
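+ # Each (subject, state) pair below is rendered through state_template, encoded with the
+ # text encoder and mean-pooled into a single vector; the per-frame state weights are then
+ # flattened to (batch, frames, num_subjects * num_states) so they line up with those vectors.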
phrases = [] + for subject, states in zip(subject_names, state_texts): + for state in states: + phrases.append(config.state_template.format(subject=subject, state=state)) + ids, mask = pipe.tokenizer(phrases, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + emb = pipe.text_encoder(ids, mask) + pooled = (emb * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1) + state_vectors = pooled.to(dtype=prompt_vecs.dtype, device="cpu") + state_len = state_vectors.shape[0] + weights = torch.as_tensor(config.state_weights, dtype=torch.float32) + if weights.dim() == 3: + weights = weights.unsqueeze(0) + if weights.dim() != 4: + raise ValueError("state_weights must be (M,F,S) or (B,M,F,S)") + if weights.shape[1] >= len(config.subjects) and len(valid_subjects) != len(config.subjects): + weights = weights[:, valid_subjects] + if weights.shape[1] != len(subject_names) or weights.shape[3] != state_count: + raise ValueError("state_weights shape does not match state_texts") + weights = weights[:, :len(subject_names)] + weights = weights.permute(0, 2, 1, 3).contiguous() + weights = weights.reshape(weights.shape[0], weights.shape[1], weights.shape[2] * weights.shape[3]) + state_weights = weights.to(device="cpu") + state = { + "enable_sci": bool(config.enable_sci), + "enable_lam": bool(config.enable_lam) and bboxes is not None, + "subject_token_mask": subject_token_mask, + "saliency": saliency, + "delta": delta, + "layout_bboxes": bboxes, + "state_vectors": state_vectors, + "state_weights": state_weights, + "state_scale": float(config.state_scale), + "prompt_len": int(prompt_ids.shape[0]), + "state_len": int(state_len), + "timestep_scale": 1000.0, + "apply_to_negative": bool(config.apply_to_negative), + } + if negative_prompt and prompt == negative_prompt: + pipe._comp_attn_state_neg = state + else: + pipe._comp_attn_state_pos = state + return {"comp_attn_state": state} + + +class CompAttnMergeUnit(PipelineUnit): + def __init__(self): + super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",)) + + def process(self, pipe, cfg_merge) -> dict: + if not cfg_merge: + return {} + state_pos = getattr(pipe, "_comp_attn_state_pos", None) + state_neg = getattr(pipe, "_comp_attn_state_neg", None) + merged = state_pos or state_neg + if merged is None: + return {} + merged = dict(merged) + apply_to_negative = bool(merged.get("apply_to_negative", False)) + merged["apply_mask"] = torch.tensor([1.0, 1.0 if apply_to_negative else 0.0]) + return {"comp_attn_state": merged} + + +def patch_cross_attention(pipe) -> None: + for block in pipe.dit.blocks: + cross_attn = block.cross_attn + if getattr(cross_attn, "_comp_attn_patched", False): + continue + orig_forward = cross_attn.forward + + def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe): + state = getattr(_pipe, "_comp_attn_runtime_state", None) + enable_lam = bool(state.get("enable_lam", False)) if state else False + enable_state = bool(state.get("state_token_weights") is not None) if state else False + if state is None or (not enable_lam and not enable_state): + return _orig(x, y) + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + lam_out = lam_attention(q, k, v, self.num_heads, state) + if lam_out is None: + out = self.attn(q, k, v) + else: + out = lam_out + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + 
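+ # The 257 CLIP image tokens keep the original attention path; their output is added to
+ # the (possibly LAM-biased) text-token attention output before the final projection.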
img_out = self.attn(q, k_img, v_img) + out = out + img_out + return self.o(out) + + cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__) + cross_attn._comp_attn_patched = True + + +def get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]: + f = latents.shape[2] // patch_size[0] + h = latents.shape[3] // patch_size[1] + w = latents.shape[4] // patch_size[2] + return f, h, w + + +def wrap_model_fn(pipe) -> None: + if getattr(pipe, "_comp_attn_model_fn_patched", False): + return + orig_model_fn = pipe.model_fn + + def model_fn_wrapper(*args, **kwargs): + comp_attn_state = kwargs.pop("comp_attn_state", None) + height = kwargs.get("height") + width = kwargs.get("width") + num_frames = kwargs.get("num_frames") + if num_frames is not None: + pipe._comp_attn_num_frames = num_frames + if comp_attn_state is None: + return orig_model_fn(*args, **kwargs) + latents = kwargs.get("latents") + timestep = kwargs.get("timestep") + context = kwargs.get("context") + clip_feature = kwargs.get("clip_feature") + reference_latents = kwargs.get("reference_latents") + state_vectors = comp_attn_state.get("state_vectors") + state_weights = comp_attn_state.get("state_weights") + state_len = int(comp_attn_state.get("state_len", 0)) + prompt_len = int(comp_attn_state.get("prompt_len", context.shape[1] if context is not None else 0)) + if context is not None and timestep is not None: + context = apply_sci(context, comp_attn_state, timestep) + if state_vectors is not None and state_len > 0: + state_vec = state_vectors.to(device=context.device, dtype=context.dtype) + if state_vec.dim() == 2: + state_vec = state_vec.unsqueeze(0) + if state_vec.shape[0] != context.shape[0]: + state_vec = state_vec.repeat(context.shape[0], 1, 1) + context = torch.cat([context, state_vec], dim=1) + kwargs["context"] = context + subject_mask = comp_attn_state.get("subject_token_mask") + if subject_mask is not None: + clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0 + pad_clip = torch.zeros((subject_mask.shape[0], clip_len), dtype=torch.bool) + pad_state = torch.zeros((subject_mask.shape[0], state_len), dtype=torch.bool) + comp_attn_state["subject_token_mask_lam"] = torch.cat([pad_clip, subject_mask.cpu(), pad_state], dim=1) + if state_vectors is not None and state_len > 0: + clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0 + pad_prompt = torch.zeros((state_len, clip_len + prompt_len), dtype=torch.bool) + ones_state = torch.ones((state_len, state_len), dtype=torch.bool) + state_token_mask = torch.cat([pad_prompt, ones_state], dim=1).any(dim=0) + comp_attn_state["state_token_mask"] = state_token_mask + if latents is not None and height is not None and width is not None: + f, h, w = get_grid_from_latents(latents, pipe.dit.patch_size) + if comp_attn_state.get("enable_lam", False): + q_len = f * h * w + if reference_latents is not None: + q_len = (f + 1) * h * w + layout_mask = comp_attn_state.get("layout_mask") + layout_shape = comp_attn_state.get("layout_shape") + if layout_mask is None or layout_shape != (latents.shape[0], q_len): + layout_mask = build_layout_mask_from_bboxes( + comp_attn_state.get("layout_bboxes"), + (f, h, w), + (int(height), int(width)), + device=latents.device, + dtype=latents.dtype, + ) + if reference_latents is not None: + pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype) + 
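+ # When reference latents prepend an extra frame, an all-zero block of size h*w is
+ # concatenated in front of the layout mask so the added query positions receive no layout bias.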
layout_mask = torch.cat([pad, layout_mask], dim=-1) + if layout_mask.shape[0] != latents.shape[0]: + layout_mask = layout_mask.repeat(latents.shape[0], 1, 1) + comp_attn_state["layout_mask"] = layout_mask + comp_attn_state["layout_shape"] = (latents.shape[0], q_len) + if state_weights is not None: + weights = state_weights.to(device=latents.device, dtype=latents.dtype) + if weights.shape[0] != latents.shape[0]: + weights = weights.repeat(latents.shape[0], 1, 1) + if weights.shape[1] != f: + weights = weights.transpose(1, 2) + weights = F.interpolate(weights, size=f, mode="linear", align_corners=True) + weights = weights.transpose(1, 2) + if reference_latents is not None: + pad = torch.zeros((weights.shape[0], 1, weights.shape[2]), device=weights.device, dtype=weights.dtype) + weights = torch.cat([pad, weights], dim=1) + f = f + 1 + comp_attn_state["state_token_weights"] = weights + comp_attn_state["grid_shape"] = (f, h, w) + if ( + latents is not None + and latents.shape[0] == 2 + and not comp_attn_state.get("apply_to_negative", False) + and "apply_mask" not in comp_attn_state + ): + comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype) + pipe._comp_attn_runtime_state = comp_attn_state + try: + return orig_model_fn(*args, **kwargs) + finally: + pipe._comp_attn_runtime_state = None + + pipe.model_fn = model_fn_wrapper + pipe._comp_attn_model_fn_patched = True diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..be2ee5876d2f59de7f03617e1b67e4c52d59192d --- /dev/null +++ b/diffsynth/models/dinov3_image_encoder.py @@ -0,0 +1,94 @@ +from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast +from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig +import torch + + +class DINOv3ImageEncoder(DINOv3ViTModel): + def __init__(self): + config = DINOv3ViTConfig( + architectures = [ + "DINOv3ViTModel" + ], + attention_dropout = 0.0, + drop_path_rate = 0.0, + dtype = "float32", + hidden_act = "silu", + hidden_size = 4096, + image_size = 224, + initializer_range = 0.02, + intermediate_size = 8192, + key_bias = False, + layer_norm_eps = 1e-05, + layerscale_value = 1.0, + mlp_bias = True, + model_type = "dinov3_vit", + num_attention_heads = 32, + num_channels = 3, + num_hidden_layers = 40, + num_register_tokens = 4, + patch_size = 16, + pos_embed_jitter = None, + pos_embed_rescale = 2.0, + pos_embed_shift = None, + proj_bias = True, + query_bias = False, + rope_theta = 100.0, + transformers_version = "4.56.1", + use_gated_mlp = True, + value_bias = False + ) + super().__init__(config) + self.processor = DINOv3ViTImageProcessorFast( + crop_size = None, + data_format = "channels_first", + default_to_square = True, + device = None, + disable_grouping = None, + do_center_crop = None, + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.485, + 0.456, + 0.406 + ], + image_processor_type = "DINOv3ViTImageProcessorFast", + image_std = [ + 0.229, + 0.224, + 0.225 + ], + input_data_format = None, + resample = 2, + rescale_factor = 0.00392156862745098, + return_tensors = None, + size = { + "height": 224, + "width": 224 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) + bool_masked_pos = None + head_mask = 
None + + pixel_values = pixel_values.to(torch_dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=layer_head_mask, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return pooled_output diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..a08c579a77c8aa1c2b35c0c168517335a636097a --- /dev/null +++ b/diffsynth/models/flux2_dit.py @@ -0,0 +1,1057 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch, math +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
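+
+ Example (illustrative): `get_timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), 320)`
+ returns a `(3, 320)` tensor, with sine features in the first 160 columns and cosine
+ features in the last 160 under the default arguments.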
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AdaLayerNormContinuous(nn.Module): + r""" + Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. 
+ # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
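+
+ Note: with `use_real=True` and `use_real_unbind_dim=-1` (the Flux path), each feature pair
+ `(x_{2i}, x_{2i+1})` is rotated as `(x_{2i}*cos - x_{2i+1}*sin, x_{2i+1}*cos + x_{2i}*sin)`,
+ with the angle shared inside the pair via the interleaved `cos`/`sin` tensors.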
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. 
+ """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [Flux2AttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + 
processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2ParallelSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + +class Flux2ParallelSelfAttention(torch.nn.Module): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. 
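+
+ The fused input projection maps `query_dim -> 3 * inner_dim + mlp_mult_factor * int(query_dim * mlp_ratio)`
+ and the fused output projection maps `inner_dim + int(query_dim * mlp_ratio) -> out_dim`.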
+ """ + + _default_processor_cls = Flux2ParallelSelfAttnProcessor + _available_processors = [Flux2ParallelSelfAttnProcessor] + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + # Fused attention output projection + MLP output projection + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. 
+ self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2ParallelSelfAttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + else: + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + 
shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + 
guidance_emb + + return time_guidance_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2DiT(torch.nn.Module): + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) + # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. 
Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ) -> Union[torch.Tensor]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 0. Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # 4. 
Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 6. Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output diff --git a/diffsynth/models/flux2_text_encoder.py b/diffsynth/models/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c68411f3160655ebefd49dfb6424b19373a301 --- /dev/null +++ b/diffsynth/models/flux2_text_encoder.py @@ -0,0 +1,58 @@ +from transformers import Mistral3ForConditionalGeneration, Mistral3Config + + +class Flux2TextEncoder(Mistral3ForConditionalGeneration): + def __init__(self): + config = Mistral3Config(**{ + "architectures": [ + "Mistral3ForConditionalGeneration" + ], + "dtype": "bfloat16", + "image_token_index": 10, + "model_type": "mistral3", + "multimodal_projector_bias": False, + "projector_hidden_act": "gelu", + "spatial_merge_size": 2, + "text_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + "max_position_embeddings": 131072, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 40, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000000.0, + "sliding_window": None, + "use_cache": True, + "vocab_size": 131072 + }, + "transformers_version": "4.57.1", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 1024, + "image_size": 1540, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "pixtral", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "rope_theta": 10000.0 + }, + "vision_feature_layer": -1 + }) + super().__init__(config) + + def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs): + return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, 
output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs) + diff --git a/diffsynth/models/flux2_vae.py b/diffsynth/models/flux2_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c7904b17618cc3c0811e42fde0f80ecd8f15f7ee --- /dev/null +++ b/diffsynth/models/flux2_vae.py @@ -0,0 +1,2322 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Dict, Optional, Tuple, Union, Callable + +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +import inspect + +ACT2CLS = { + "swish": nn.SiLU, + "silu": nn.SiLU, + "mish": nn.Mish, + "gelu": nn.GELU, + "relu": nn.ReLU, +} + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACT2CLS: + return ACT2CLS[act_fn]() + else: + raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a + stronger conditioning with scale and shift. + kernel (`torch.Tensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. 
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + if time_embedding_norm == "ada_group": + raise ValueError( + "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", + ) + if time_embedding_norm == "spatial": + raise ValueError( + "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", + ) + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, 
temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) + time_scale, time_shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + time_scale) + time_shift + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor.contiguous()) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. 
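+        kernel_size (`int`, default `3`):
+            kernel size of the convolution used when `use_conv` is `True`.
+        norm_type (`str`, *optional*, defaults to `None`):
+            optional normalization applied over the channel dimension before downsampling; one of
+            `"ln_norm"`, `"rms_norm"` or `None`.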
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. 
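+        interpolate (`bool`, default `True`):
+            whether to upsample with nearest-neighbour interpolation; if `False` and
+            `use_conv_transpose` is `False`, only the optional convolution is applied.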
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 + # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + # upsample_nearest_nhwc also fails when the number of output elements is large + # https://github.com/pytorch/pytorch/issues/141831 + scale_factor = ( + 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) + ) + if hidden_states.numel() * scale_factor > pow(2, 31): + hidden_states = hidden_states.contiguous() + + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # Cast back to original dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. 
Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. 
+ # from .normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." 
+ ) + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "layer_norm": + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # Wan applies qk norm across all heads + # Wan also doesn't apply a q norm + self.norm_added_q = None + self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_xla_flash_attention( + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, + ) -> None: + r""" + Set whether to use xla flash attention from `torch_xla` or not. + + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + """ + if use_xla_flash_attention: + if not is_torch_xla_available: + raise "torch_xla is not available" + elif is_torch_xla_version("<", "2.3"): + raise "flash attention pallas kernel is supported from torch_xla version 2.3" + elif is_spmd() and is_torch_xla_version("<", "2.4"): + raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" + else: + if is_flux: + processor = XLAFluxFlashAttnProcessor2_0(partition_spec) + else: + processor = XLAFlashAttnProcessor2_0(partition_spec) + else: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. 
+ """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + is_ip_adapter = hasattr(self, "processor") and isinstance( + self.processor, + (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), + ) + is_joint_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + JointAttnProcessor2_0, + XFormersJointAttnProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xformers.ops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." 
+ ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + elif is_ip_adapter: + processor = IPAdapterXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + elif is_joint_processor: + processor = XFormersJointAttnProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_ip_adapter: + processor = IPAdapterAttnProcessor2_0( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
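+
+        If the current processor is an `nn.Module` and the new `processor` is not, the old
+        processor is popped from `self._modules`, so any weights it carried are dropped from
+        this attention module.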
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. 
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
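+                # Like the `F.pad` call in the branch below, this pads the mask by `target_length`
+                # extra positions (see the TODO there about padding by the remaining length instead).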
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0, output_size=attention_mask.shape[0] * head_size + ) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave( + head_size, dim=1, output_size=attention_mask.shape[1] * head_size + ) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. 
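+        # If the block also carries context-stream projections (`add_q_proj`/`add_k_proj`/`add_v_proj`,
+        # e.g. for joint text-image attention), fuse them into a single `to_added_qkv` linear in the
+        # same way as the main projections above.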
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = 
F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. 
+ + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) 
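+        # Note: `self.resnets` holds one more block than `self.attentions`; the first resnet runs
+        # before the loop, then each (attention, resnet) pair is applied in order.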
+ + return hidden_states + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock2D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + # temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # down + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. 
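        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether the mid block includes an attention block in addition to its resnet blocks.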
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock2D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + # prev_output_channel=prev_output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Flux2VAE(torch.nn.Module): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: Tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary 
containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ): + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
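            Note that, as implemented below, the encoder output is additionally rearranged into
            2x2 latent patches, truncated to the first 128 channels, and standardized with the
            tracked BatchNorm running statistics before being returned.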
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + + h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2) + h = h[:, :128] + latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype) + latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( + h.device, h.dtype + ) + h = (h - latents_bn_mean) / latents_bn_std + return h + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return dec + + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ): + latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype) + latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( + z.device, z.dtype + ) + z = z * latents_bn_std + latents_bn_mean + z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2) + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True): + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + return moments + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ): + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
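            generator (`torch.Generator`, *optional*):
                A random generator used when `sample_posterior` is `True` and the posterior is sampled.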
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return dec diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb1138bb74f7b55c6e92ac312098e4168829f0a --- /dev/null +++ b/diffsynth/models/flux_controlnet.py @@ -0,0 +1,384 @@ +import torch +from einops import rearrange, repeat +from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm +# from .utils import hash_state_dict_keys, init_weights_on_device +from contextlib import contextmanager + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + +@contextmanager +def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): + + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + +class FluxControlNet(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(64, 3072) + + self.blocks = 
torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)]) + + self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)]) + self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)]) + + self.mode_dict = mode_dict + self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None + self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072) + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states): + if len(res_stack) == 0: + return [torch.zeros_like(hidden_states)] * num_blocks + interval = (num_blocks + len(res_stack) - 1) // len(res_stack) + aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)] + return aligned_res_stack + + + def forward( + self, + hidden_states, + controlnet_conditioning, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + processor_id=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) + if self.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + if self.controlnet_mode_embedder is not None: # Different from FluxDiT + processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int) + processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device) + prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1) + text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + hidden_states = self.patchify(hidden_states) + hidden_states = self.x_embedder(hidden_states) + controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT + + controlnet_res_stack = [] + for block, controlnet_block in zip(self.blocks, self.controlnet_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + 
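# Illustrative arithmetic for align_res_stack_to_original_blocks above: a short ControlNet
# residual stack is stretched to cover every block of the base DiT (19 joint blocks here).
# Plain integers stand in for the residual tensors.
num_blocks, res_stack = 19, list(range(5))
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)   # ceil(19 / 5) == 4
aligned = [res_stack[block_id // interval] for block_id in range(num_blocks)]
# aligned == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]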
controlnet_res_stack.append(controlnet_block(hidden_states)) + + controlnet_single_res_stack = [] + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:])) + + controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:]) + controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:]) + + return controlnet_res_stack, controlnet_single_res_stack + + + # @staticmethod + # def state_dict_converter(): + # return FluxControlNetStateDictConverter() + + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class QLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self,input) + return torch.nn.functional.linear(input,weight,bias) + + class QRMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + class QEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight= cast_weight(self,input) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module,quantized_layer.QRMSNorm): + continue + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.QLinear(module.in_features,module.out_features) + new_layer.weight = module.weight + if module.bias is not None: + new_layer.bias = module.bias + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True + new_layer = quantized_layer.QRMSNorm(module) + setattr(model, name, new_layer) + elif isinstance(module,torch.nn.Embedding): + 
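# The wrappers above keep parameters in their storage dtype and cast them to the input's
# dtype/device on every call, so the module can be held in reduced precision while computing in
# the activation dtype. A minimal sketch of that cast-on-forward pattern (toy shapes and dtypes):
import torch

w_store = torch.randn(8, 8, dtype=torch.float16)       # parameter kept in a storage dtype
x = torch.randn(2, 8, dtype=torch.bfloat16)            # activations arrive in another dtype
w = w_store.to(dtype=x.dtype, device=x.device)         # what cast_to / cast_weight do per call
y = torch.nn.functional.linear(x, w)                   # y.dtype == torch.bfloat16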
rows, cols = module.weight.shape + new_layer = quantized_layer.QEmbedding( + num_embeddings=rows, + embedding_dim=cols, + _weight=module.weight, + # _freeze=module.freeze, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) + + + +class FluxControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + hash_value = hash_state_dict_keys(state_dict) + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." 
in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + if hash_value == "78d18b9101345ff695f312e7e62538c0": + extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}} + elif hash_value == "b001c89139b5f053c715fe772362dd2a": + extra_kwargs = {"num_single_blocks": 0} + elif hash_value == "52357cb26250681367488a8954c271e8": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} + elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} + elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10} + elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0} + else: + extra_kwargs = {} + return state_dict_, extra_kwargs + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..51a6e7f049d15e764383487c0edbf7da839b5918 --- /dev/null +++ b/diffsynth/models/flux_dit.py @@ -0,0 +1,395 @@ +import torch +from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm +from einops import rearrange + + +def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size, num_tokens = hidden_states.shape[0:2] + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + +class RoPEEmbedding(torch.nn.Module): + def __init__(self, dim, theta, axes_dim): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + + def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." 
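# Shape sketch for the rotary table built by this module, assuming the RoPEEmbedding(3072, 10000,
# [16, 56, 56]) configuration used by FluxDiT / FluxControlNet in this file (toy token count):
import torch

rope_emb = RoPEEmbedding(3072, 10000, [16, 56, 56])
ids = torch.zeros(1, 512 + 4096, 3)        # (batch, text + image tokens, 3 position axes)
freqs = rope_emb(ids)
# freqs.shape == (1, 1, 4608, 64, 2, 2): one 2x2 rotation matrix per token and frequency,
# with 64 == (16 + 56 + 56) / 2 frequencies matching the 128-dim attention heads in apply_rope.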
+ + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + + +class FluxJointAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.only_out_a = only_out_a + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + self.norm_q_b = RMSNorm(head_dim, eps=1e-6) + self.norm_k_b = RMSNorm(head_dim, eps=1e-6) + + self.a_to_out = torch.nn.Linear(dim_a, dim_a) + if not only_out_a: + self.b_to_out = torch.nn.Linear(dim_b, dim_b) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states_a.shape[0] + + # Part A + qkv_a = self.a_to_qkv(hidden_states_a) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v_a = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + # Part B + qkv_b = self.b_to_qkv(hidden_states_b) + qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_b, k_b, v_b = qkv_b.chunk(3, dim=1) + q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b) + + q = torch.concat([q_b, q_a], dim=2) + k = torch.concat([k_b, k_a], dim=2) + v = torch.concat([v_b, v_a], dim=2) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] + if ipadapter_kwargs_list is not None: + hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) + hidden_states_a = self.a_to_out(hidden_states_a) + if self.only_out_a: + return hidden_states_a + else: + hidden_states_b = self.b_to_out(hidden_states_b) + return hidden_states_a, hidden_states_b + + + +class FluxJointTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.norm1_a = AdaLayerNorm(dim) + self.norm1_b = AdaLayerNorm(dim) + + self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads) + + self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_a = torch.nn.Sequential( + 
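# Note on the modulation used in forward() below: norm1_a / norm1_b (AdaLayerNorm) consume `temb`
# and return the already shifted-and-scaled hidden states for attention together with a gate for
# the attention residual and a (shift, scale, gate) triple for this feed-forward branch, i.e.
# x = x + gate_msa * attn(mod(x)); x = x + gate_mlp * ff(norm2(x) * (1 + scale_mlp) + shift_mlp).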
torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_b = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + + # Attention + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + + # Part A + hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a + hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) + + # Part B + hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b + norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b + hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) + + return hidden_states_a, hidden_states_b + + + +class FluxSingleAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def forward(self, hidden_states, image_rotary_emb): + batch_size = hidden_states.shape[0] + + qkv_a = self.a_to_qkv(hidden_states) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + q, k = self.apply_rope(q_a, k_a, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + return hidden_states + + + +class AdaLayerNormSingle(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, 3 * dim, bias=True) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + + +class FluxSingleTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = dim // num_attention_heads + self.dim = dim + + self.norm = AdaLayerNormSingle(dim) + self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4)) + self.norm_q_a = RMSNorm(self.head_dim, 
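# to_qkv_mlp above fuses the attention and MLP input projections: of its dim * 7 outputs, the
# first 3 * dim are split into q/k/v inside process_attention() and the remaining 4 * dim feed
# the GELU MLP branch; proj_out then maps the concatenated (dim + 4 * dim) features back to dim.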
eps=1e-6) + self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6) + + self.proj_out = torch.nn.Linear(dim * 5, dim) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states.shape[0] + + qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q, k, v = qkv.chunk(3, dim=1) + q, k = self.norm_q_a(q), self.norm_k_a(k) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + if ipadapter_kwargs_list is not None: + hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list) + return hidden_states + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + residual = hidden_states_a + norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) + hidden_states_a = self.to_qkv_mlp(norm_hidden_states) + attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] + + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") + + hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) + hidden_states_a = residual + hidden_states_a + + return hidden_states_a, hidden_states_b + + + +class AdaLayerNormContinuous(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, dim * 2, bias=True) + self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False) + + def forward(self, x, conditioning): + emb = self.linear(self.silu(conditioning)) + shift, scale = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None] + shift[:, None] + return x + + + +class FluxDiT(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) + + self.final_norm_out = AdaLayerNormContinuous(3072) + self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim + + + def patchify(self, hidden_states): + hidden_states = 
rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) + return hidden_states + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask + + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim): + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + 
prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, + use_gradient_checkpointing=False, + **kwargs + ): + # (Deprecated) The real forward is in `pipelines.flux_image`. + return None diff --git a/diffsynth/models/flux_infiniteyou.py b/diffsynth/models/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..861538a4b02fb6a52edee662b6efcd60f78f6916 --- /dev/null +++ b/diffsynth/models/flux_infiniteyou.py @@ -0,0 +1,129 @@ +import math +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class InfiniteYouImageProjector(nn.Module): + + def __init__( + self, + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=8, + embedding_dim=512, + output_dim=4096, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + 
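# Each depth level pairs a PerceiverAttention with a FeedForward: the learned latent queries
# attend over the concatenation of the input image embedding and themselves, then pass through
# the MLP, both with residual connections (see forward() below), distilling the embedding into
# `num_queries` tokens of width `dim` before proj_out / norm_out map them to `output_dim`.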
FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + latents = latents.to(dtype=x.dtype, device=x.device) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + @staticmethod + def state_dict_converter(): + return FluxInfiniteYouImageProjectorStateDictConverter() + + +class FluxInfiniteYouImageProjectorStateDictConverter: + + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict['image_proj'] diff --git a/diffsynth/models/flux_ipadapter.py b/diffsynth/models/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..31176fc2c2a508388502b45dc27e4d2218f16eec --- /dev/null +++ b/diffsynth/models/flux_ipadapter.py @@ -0,0 +1,110 @@ +from .general_modules import RMSNorm +from transformers import SiglipVisionModel, SiglipVisionConfig +import torch + + +class SiglipVisionModelSO400M(SiglipVisionModel): + def __init__(self): + config = SiglipVisionConfig( + hidden_size=1152, + image_size=384, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + patch_size=14, + architectures=["SiglipModel"], + initializer_factor=1.0, + torch_dtype="float32", + transformers_version="4.37.0.dev0" + ) + super().__init__(config) + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class IpAdapterModule(torch.nn.Module): + def __init__(self, num_attention_heads, attention_head_dim, input_dim): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + output_dim = num_attention_heads * attention_head_dim + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False) + + + def forward(self, hidden_states): + batch_size = hidden_states.shape[0] + # ip_k + ip_k = self.to_k_ip(hidden_states) + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_k = self.norm_added_k(ip_k) + # ip_v + ip_v = self.to_v_ip(hidden_states) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + return ip_k, ip_v + + +class FluxIpAdapter(torch.nn.Module): + def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57): + super().__init__() + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)]) + self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens) + self.set_adapter() + + def set_adapter(self): + self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))} + + 
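+ # `set_adapter` fills `call_block_id` with an identity mapping {block_id: block_id}, so each of the + # `num_blocks` DiT attention blocks is paired with its own `IpAdapterModule`. + # `forward` (below) projects the image embedding (SigLIP-so400m features in this file's setup) through + # `image_proj` into `num_tokens` context tokens, then precomputes per-block key/value tensors + # (`ip_k`, `ip_v`) together with the blending `scale` for the attention layers to consume.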
def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for block_id in self.call_block_id: + ipadapter_id = self.call_block_id[block_id] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + ip_kv_dict[block_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + @staticmethod + def state_dict_converter(): + return FluxIpAdapterStateDictConverter() + + +class FluxIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..13589b0611f3140479ef4faa3b7a29371caa447b --- /dev/null +++ b/diffsynth/models/flux_lora_encoder.py @@ -0,0 +1,521 @@ +import torch +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = 
hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = 
self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": 
"encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in, dim_out): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + self.layer_norm = torch.nn.LayerNorm(dim_out) + + def forward(self, lora_A, lora_B): + x = self.x @ lora_A.T @ lora_B.T + x = self.layer_norm(x) + return x + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1]) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + layer_type, dim = lora_pattern["type"], lora_pattern["dim"] + if layer_type not in proj_dict: + 
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class FluxLoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1): + super().__init__() + self.num_embeds_per_lora = num_embeds_per_lora + # embedder + self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)]) + + # special embedding + self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim)) + self.num_special_embeds = num_special_embeds + + # final layer + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + self.final_linear = torch.nn.Linear(embed_dim, embed_dim) + + def forward(self, lora): + lora_embeds = self.embedder(lora) + special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device) + embeds = torch.concat([special_embeds, lora_embeds], dim=1) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = embeds[:, :self.num_special_embeds] + embeds = self.final_layer_norm(embeds) + embeds = self.final_linear(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return FluxLoRAEncoderStateDictConverter() + + +class FluxLoRAEncoderStateDictConverter: + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/flux_lora_patcher.py b/diffsynth/models/flux_lora_patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8fc8cea03bbbc658d8e1869432d903b6ae7ce9 --- /dev/null +++ b/diffsynth/models/flux_lora_patcher.py @@ -0,0 +1,306 @@ +import torch, math +from ..core.loader import load_state_dict +from typing import Union + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_B." 
not in key: + continue + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) + return lora_name_dict + + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + updated_num = 0 + lora_name_dict = self.get_name_dict(state_dict_lora) + for name, module in model.named_modules(): + if name in lora_name_dict: + weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict = module.state_dict() + state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict) + updated_num += 1 + print(f"{updated_num} tensors are updated by LoRA.") + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight", + 
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": 
"blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().load(model, state_dict_lora, alpha) + + + def convert_state_dict(self,state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else 
self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class FluxLoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return 
self.model_dict[name.replace(".", "___")](base_output, lora_outputs) diff --git a/diffsynth/models/flux_text_encoder_clip.py b/diffsynth/models/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1425423ce6d1df946198a16a7e96078ab8fed807 --- /dev/null +++ b/diffsynth/models/flux_text_encoder_clip.py @@ -0,0 +1,112 @@ +import torch + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FluxTextEncoderClip(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = 
torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=2, extra_mask=None): + embeds = self.token_embedding(input_ids) + embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device) + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + if extra_mask is not None: + attn_mask[:, extra_mask[0]==0] = float("-inf") + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + hidden_states = embeds + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + return pooled_embeds, hidden_states diff --git a/diffsynth/models/flux_text_encoder_t5.py b/diffsynth/models/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..ee72e4a89b2089b62c6ea86c4aef91755ee9ee9a --- /dev/null +++ b/diffsynth/models/flux_text_encoder_t5.py @@ -0,0 +1,43 @@ +import torch +from transformers import T5EncoderModel, T5Config + + +class FluxTextEncoderT5(T5EncoderModel): + def __init__(self): + config = T5Config(**{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "dtype": "bfloat16", + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": True, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": False, + "transformers_version": "4.57.1", + "use_cache": True, + "vocab_size": 32128 + }) + super().__init__(config) + + def forward(self, input_ids): + outputs = super().forward(input_ids=input_ids) + prompt_emb = outputs.last_hidden_state + return prompt_emb diff --git a/diffsynth/models/flux_vae.py b/diffsynth/models/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5eabeaee6ad54f0b1c02f1b24cd2ccd84d0238bf --- /dev/null +++ b/diffsynth/models/flux_vae.py @@ -0,0 +1,451 @@ +import torch +from einops import rearrange, repeat + + +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 
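+        # Each entry is min(i + 1, height - i, j + 1, width - j), i.e. the 1-indexed distance to
+        # the nearest border, divided by border_width and clipped to [0, 1]. For example, a 4x4
+        # mask with border_width=2 is 0.5 on the outer ring and 1.0 on the central 2x2 block.
+        # `untile` later uses this linear ramp to blend overlapping tiles.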
+ x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happened in forward_fn + # We only consider the same scale on height and width. + io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # 
Done! + model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output + + +class ConvAttention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q) + self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + q = self.to_q(conv_input) + q = rearrange(q[:, :, :, 0], "B C L -> B L C") + conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1") + k = self.to_k(conv_input) + v = self.to_v(conv_input) + k = rearrange(k[:, :, :, 0], "B C L -> B L C") + v = rearrange(v[:, :, :, 0], "B C L -> B L C") + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + hidden_states = self.to_out(conv_input) + hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C") + + return hidden_states + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, 
in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + if use_conv_attention: + self.transformer_blocks = torch.nn.ModuleList([ + ConvAttention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + else: + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class ResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = hidden_states + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb)[:, :, None, None] + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +class UpSampler(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class DownSampler(torch.nn.Module): + def __init__(self, channels, padding=1, extra_padding=False): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) + self.extra_padding = extra_padding + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + if self.extra_padding: + hidden_states = 
torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class FluxVAEDecoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock2D + ResnetBlock(256, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = sample / self.scaling_factor + self.shift_factor + hidden_states = self.conv_in(hidden_states) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class FluxVAEEncoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownEncoderBlock2D + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + DownSampler(128, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(128, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + DownSampler(256, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(256, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + DownSampler(512, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = self.conv_in(sample) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = hidden_states[:, :16] + hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor + + return hidden_states + + def encode_video(self, sample, batch_size=8): + B = sample.shape[0] + hidden_states = [] + + for i in range(0, sample.shape[2], batch_size): + + j = min(i + batch_size, sample.shape[2]) + sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W") + + hidden_states_batch = self(sample_batch) + hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B) + + hidden_states.append(hidden_states_batch) + + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py new file mode 100644 index 0000000000000000000000000000000000000000..549dbc93b41343a42266af11584e2e7d39a17cd6 --- /dev/null +++ b/diffsynth/models/flux_value_control.py @@ -0,0 +1,56 @@ +import torch +from .general_modules import TemporalTimesteps + + +class MultiValueEncoder(torch.nn.Module): + def __init__(self, encoders=()): + super().__init__() + if not isinstance(encoders, list): + encoders = [encoders] + self.encoders = torch.nn.ModuleList(encoders) + + def __call__(self, values, dtype): + emb = [] + for encoder, value in zip(self.encoders, values): + if value is not None: + value = value.unsqueeze(0) + emb.append(encoder(value, dtype)) + emb = torch.concat(emb, dim=0) + return emb + + +class SingleValueEncoder(torch.nn.Module): + def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None): + super().__init__() + self.prefer_len = prefer_len + self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device) + self.prefer_value_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.positional_embedding = torch.nn.Parameter( + torch.randn(self.prefer_len, dim_out) + ) + + def forward(self, value, dtype): + value = value * 1000 + emb = self.prefer_proj(value).to(dtype) + emb = self.prefer_value_embedder(emb).squeeze(0) + base_embeddings = emb.expand(self.prefer_len, -1) + positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device) + learned_embeddings = base_embeddings + positional_embedding + return learned_embeddings + + @staticmethod + def state_dict_converter(): + return SingleValueEncoderStateDictConverter() + + +class SingleValueEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/general_modules.py b/diffsynth/models/general_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..216247c6c1f1c9fbba1e6a7f9631cfeec2684743 --- /dev/null +++ b/diffsynth/models/general_modules.py @@ -0,0 +1,139 @@ +import torch, math + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, + computation_device = None, + align_dtype_to_timestep = False, +): + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = 
-math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.device) + if align_dtype_to_timestep: + emb = emb.to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.computation_device = computation_device + self.scale = scale + self.align_dtype_to_timestep = align_dtype_to_timestep + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + computation_device=self.computation_device, + scale=self.scale, + align_dtype_to_timestep=self.align_dtype_to_timestep, + ) + return t_emb + + +class DiffusersCompatibleTimestepProj(torch.nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.linear_1 = torch.nn.Linear(dim_in, dim_out) + self.act = torch.nn.SiLU() + self.linear_2 = torch.nn.Linear(dim_out, dim_out) + + def forward(self, x): + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TimestepEmbeddings(torch.nn.Module): + def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False): + super().__init__() + self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep) + if diffusers_compatible_format: + self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out) + else: + self.timestep_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + + def forward(self, timestep, dtype): + time_emb = self.time_proj(timestep).to(dtype) + time_emb = self.timestep_embedder(time_emb) + return time_emb + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps, elementwise_affine=True): + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.ones((dim,))) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(input_dtype) + if self.weight is not None: + hidden_states = hidden_states * self.weight + return hidden_states + + +class AdaLayerNorm(torch.nn.Module): + def __init__(self, dim, single=False, dual=False): + super().__init__() + self.single = single + self.dual = dual + self.linear = torch.nn.Linear(dim, 
dim * [[6, 2][single], 9][dual]) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(torch.nn.functional.silu(emb)) + if self.single: + scale, shift = emb.unsqueeze(1).chunk(2, dim=2) + x = self.norm(x) * (1 + scale) + shift + return x + elif self.dual: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2) + norm_x = self.norm(x) + x = norm_x * (1 + scale_msa) + shift_msa + norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2 + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2 + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d657238793ef1b42f28be86c707283bd2fdff4f --- /dev/null +++ b/diffsynth/models/longcat_video_dit.py @@ -0,0 +1,901 @@ +from typing import List, Optional, Tuple + +import math +import torch +import torch.nn as nn +import torch.amp as amp + +import numpy as np +import torch.nn.functional as F +from einops import rearrange, repeat +from .wan_video_dit import flash_attention +from ..core.gradient import gradient_checkpoint_forward + + +class RMSNorm_FP32(torch.nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class RotaryPositionalEmbedding(nn.Module): + + def __init__(self, + head_dim, + cp_split_hw=None + ): + """Rotary positional embedding for 3D + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + dim: Dimension of embedding + base: Base value for exponential + """ + super().__init__() + self.head_dim = head_dim + assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.' 
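+        # The head dimension is split across (time, height, width) frequency bands below as
+        # dim_t = head_dim - 4 * (head_dim // 6) and dim_h = dim_w = 2 * (head_dim // 6),
+        # so dim_t + dim_h + dim_w == head_dim. With the default LongCat config
+        # (hidden_size 4096, 32 heads -> head_dim 128) this gives (44, 42, 42).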
+ self.cp_split_hw = cp_split_hw + # We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels + self.base = 10000 + self.freqs_dict = {} + + def register_grid_size(self, grid_size): + if grid_size not in self.freqs_dict: + self.freqs_dict.update({ + grid_size: self.precompute_freqs_cis_3d(grid_size) + }) + + def precompute_freqs_cis_3d(self, grid_size): + num_frames, height, width = grid_size + dim_t = self.head_dim - 4 * (self.head_dim // 6) + dim_h = 2 * (self.head_dim // 6) + dim_w = 2 * (self.head_dim // 6) + freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32) + grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32) + grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32) + grid_t = torch.from_numpy(grid_t).float() + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + # (T H W D) + freqs = rearrange(freqs, "T H W D -> (T H W) D") + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # with torch.no_grad(): + # freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width) + # freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw) + # freqs = rearrange(freqs, "T H W D -> (T H W) D") + + return freqs + + def forward(self, q, k, grid_size): + """3D RoPE. + + Args: + query: [B, head, seq, head_dim] + key: [B, head, seq, head_dim] + Returns: + query and key with the same shape as input. 
+ """ + + if grid_size not in self.freqs_dict: + self.register_grid_size(grid_size) + + freqs_cis = self.freqs_dict[grid_size].to(q.device) + q_, k_ = q.float(), k.float() + freqs_cis = freqs_cis.float().to(q.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + q_ = (q_ * cos) + (rotate_half(q_) * sin) + k_ = (k_ * cos) + (rotate_half(k_) * sin) + + return q_.type_as(q), k_.type_as(k) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = None, + cp_split_hw: Optional[List[int]] = None + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + self.enable_bsa = enable_bsa + self.bsa_params = bsa_params + self.cp_split_hw = cp_split_hw + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.proj = nn.Linear(dim, dim) + + self.rope_3d = RotaryPositionalEmbedding( + self.head_dim, + cp_split_hw=cp_split_hw + ) + + def _process_attn(self, q, k, v, shape): + q = rearrange(q, "B H S D -> B S (H D)") + k = rearrange(k, "B H S D -> B S (H D)") + v = rearrange(v, "B H S D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads) + return x + + def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if return_kv: + k_cache, v_cache = k.clone(), v.clone() + + q, k = self.rope_3d(q, k, shape) + + # cond mode + if num_cond_latents is not None and num_cond_latents > 0: + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + # process the condition tokens + q_cond = q[:, :, :num_cond_latents_thw].contiguous() + k_cond = k[:, :, :num_cond_latents_thw].contiguous() + v_cond = v[:, :, :num_cond_latents_thw].contiguous() + x_cond = self._process_attn(q_cond, k_cond, v_cond, shape) + # process the noise tokens + q_noise = q[:, :, num_cond_latents_thw:].contiguous() + x_noise = self._process_attn(q_noise, k, v, shape) + # merge x_cond and x_noise + x = torch.cat([x_cond, x_noise], dim=2).contiguous() + else: + x = self._process_attn(q, k, v, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + if return_kv: + return x, (k_cache, v_cache) + else: + return x + + def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + T, H, W = shape + k_cache, v_cache = 
kv_cache + assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B] + if k_cache.shape[0] == 1: + k_cache = k_cache.repeat(B, 1, 1, 1) + v_cache = v_cache.repeat(B, 1, 1, 1) + + if num_cond_latents is not None and num_cond_latents > 0: + k_full = torch.cat([k_cache, k], dim=2).contiguous() + v_full = torch.cat([v_cache, v], dim=2).contiguous() + q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous() + q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W)) + q = q_padding[:, :, -N:].contiguous() + + x = self._process_attn(q, k_full, v_full, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + enable_flashattn3=False, + enable_flashattn2=False, + enable_xformers=False, + ): + super(MultiHeadCrossAttention, self).__init__() + assert dim % num_heads == 0, "d_model must be divisible by num_heads" + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = nn.Linear(dim, dim) + self.kv_linear = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + + def _process_cross_attn(self, x, cond, kv_seqlen): + B, N, C = x.shape + assert C == self.dim and cond.shape[2] == self.dim + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + q, k = self.q_norm(q), self.k_norm(k) + + q = rearrange(q, "B S H D -> B S (H D)") + k = rearrange(k, "B S H D -> B S (H D)") + v = rearrange(v, "B S H D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + + x = x.view(B, -1, C) + x = self.proj(x) + return x + + def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None): + """ + x: [B, N, C] + cond: [B, M, C] + """ + if num_cond_latents is None or num_cond_latents == 0: + return self._process_cross_attn(x, cond, kv_seqlen) + else: + B, N, C = x.shape + if num_cond_latents is not None and num_cond_latents > 0: + assert shape is not None, "SHOULD pass in the shape" + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C] + output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C] + output = torch.cat([ + torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device), + output_noise + ], dim=1).contiguous() + else: + raise NotImplementedError + + return output + + +class LayerNorm_FP32(nn.LayerNorm): + def __init__(self, dim, eps, elementwise_affine): + super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + out = F.layer_norm( + inputs.float(), + self.normalized_shape, + None if self.weight is None else self.weight.float(), + None if self.bias is None else self.bias.float() , + self.eps + ).to(origin_dtype) + return out + + +def modulate_fp32(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D) + # ensure the modulation params be fp32 + 
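+    # Note: in `assert cond, msg` the second expression is only the failure message, so the
+    # assertion below verifies shift's dtype only; scale comes from the same chunked fp32
+    # modulation tensor in the callers.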
assert shift.dtype == torch.float32, scale.dtype == torch.float32 + dtype = x.dtype + x = norm_func(x.to(torch.float32)) + x = x * (scale + 1) + shift + x = x.to(dtype) + return x + + +class FinalLayer_FP32(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim): + super().__init__() + self.hidden_size = hidden_size + self.num_patch = num_patch + self.out_channels = out_channels + self.adaln_tembed_dim = adaln_tembed_dim + + self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True)) + + def forward(self, x, t, latent_shape): + # timestep shape: [B, T, C] + assert t.dtype == torch.float32 + B, N, C = x.shape + T, _, _ = latent_shape + + with amp.autocast('cuda', dtype=torch.float32): + shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] + x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) + x = self.linear(x) + return x + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.dim = dim + self.hidden_dim = hidden_dim + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, t_embed_dim, frequency_embedding_size=256): + super().__init__() + self.t_embed_dim = t_embed_dim + self.frequency_embedding_size = frequency_embedding_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, t_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(t_embed_dim, t_embed_dim, bias=True), + ) + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. 
+ """ + + def __init__(self, in_channels, hidden_size): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.y_proj = nn.Sequential( + nn.Linear(in_channels, hidden_size, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, caption): + B, _, N, C = caption.shape + caption = self.y_proj(caption) + return caption + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + B, C, T, H, W = x.shape + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class LongCatSingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: int, + adaln_tembed_dim: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params=None, + cp_split_hw=None + ): + super().__init__() + + self.hidden_size = hidden_size + + # scale and gate modulation + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True) + ) + + self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True) + + self.attn = Attention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + self.cross_attn = MultiHeadCrossAttention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + ) + self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio)) + + def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False): + """ + x: [B, N, C] + y: [1, N_valid_tokens, C] + t: [B, T, C_t] + y_seqlen: [B]; type of a list + latent_shape: latent shape of 
a single item + """ + x_dtype = x.dtype + + B, N, C = x.shape + T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. + + # compute modulation params in fp32 + with amp.autocast(device_type='cuda', dtype=torch.float32): + shift_msa, scale_msa, gate_msa, \ + shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] + + # self attn with modulation + x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C) + + if kv_cache is not None: + kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device)) + attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache) + else: + attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv) + + if return_kv: + x_s, kv_cache = attn_outputs + else: + x_s = attn_outputs + + with amp.autocast(device_type='cuda', dtype=torch.float32): + x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + # cross attn + if not skip_crs_attn: + if kv_cache is not None: + num_cond_latents = None + x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape) + + # ffn with modulation + x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) + x_s = self.ffn(x_m) + with amp.autocast(device_type='cuda', dtype=torch.float32): + x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + if return_kv: + return x, kv_cache + else: + return x + + +class LongCatVideoTransformer3DModel(torch.nn.Module): + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + hidden_size: int = 4096, + depth: int = 48, + num_heads: int = 32, + caption_channels: int = 4096, + mlp_ratio: int = 4, + adaln_tembed_dim: int = 512, + frequency_embedding_size: int = 256, + # default params + patch_size: Tuple[int] = (1, 2, 2), + # attention config + enable_flashattn3: bool = False, + enable_flashattn2: bool = True, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}, + cp_split_hw: Optional[List[int]] = [1, 1], + text_tokens_zero_pad: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.cp_split_hw = cp_split_hw + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + ) + + self.blocks = nn.ModuleList( + [ + LongCatSingleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + adaln_tembed_dim=adaln_tembed_dim, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer_FP32( + hidden_size, + np.prod(self.patch_size), + out_channels, + adaln_tembed_dim, + ) + + self.gradient_checkpointing = False + self.text_tokens_zero_pad = text_tokens_zero_pad + + self.lora_dict = {} + self.active_loras = [] + + def enable_loras(self, lora_key_list=[]): + 
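+        # Detach any currently active LoRAs, then wrap the forward of every module touched by
+        # the requested LoRA networks so that its base output is summed with
+        # lora_up(lora_down(x)) * multiplier * alpha_scale for each enabled LoRA.
+        # Keys missing from self.lora_dict are silently skipped.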
self.disable_all_loras() + + module_loras = {} # {module_name: [lora1, lora2, ...]} + model_device = next(self.parameters()).device + model_dtype = next(self.parameters()).dtype + + for lora_key in lora_key_list: + if lora_key in self.lora_dict: + for lora in self.lora_dict[lora_key].loras: + lora.to(model_device, dtype=model_dtype, non_blocking=True) + module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".") + if module_name not in module_loras: + module_loras[module_name] = [] + module_loras[module_name].append(lora) + self.active_loras.append(lora_key) + + for module_name, loras in module_loras.items(): + module = self._get_module_by_name(module_name) + if not hasattr(module, 'org_forward'): + module.org_forward = module.forward + module.forward = self._create_multi_lora_forward(module, loras) + + def _create_multi_lora_forward(self, module, loras): + def multi_lora_forward(x, *args, **kwargs): + weight_dtype = x.dtype + org_output = module.org_forward(x, *args, **kwargs) + + total_lora_output = 0 + for lora in loras: + if lora.use_lora: + lx = lora.lora_down(x.to(lora.lora_down.weight.dtype)) + lx = lora.lora_up(lx) + lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale + total_lora_output += lora_output + + return org_output + total_lora_output + + return multi_lora_forward + + def _get_module_by_name(self, module_name): + try: + module = self + for part in module_name.split('.'): + module = getattr(module, part) + return module + except AttributeError as e: + raise ValueError(f"Cannot find module: {module_name}, error: {e}") + + def disable_all_loras(self): + for name, module in self.named_modules(): + if hasattr(module, 'org_forward'): + module.forward = module.org_forward + delattr(module, 'org_forward') + + for lora_key, lora_network in self.lora_dict.items(): + for lora in lora_network.loras: + lora.to("cpu") + + self.active_loras.clear() + + def enable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = True + + def disable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = False + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask=None, + num_cond_latents=0, + return_kv=False, + kv_cache_dict={}, + skip_crs_attn=False, + offload_kv_cache=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + + B, _, T, H, W = hidden_states.shape + + N_t = T // self.patch_size[0] + N_h = H // self.patch_size[1] + N_w = W // self.patch_size[2] + + assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension." 
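+        # With patch_size (1, 2, 2) the token grid is (N_t, N_h, N_w) = (T, H // 2, W // 2), so
+        # every block sees T * (H // 2) * (W // 2) tokens of width hidden_size. The first
+        # num_cond_latents frames serve as clean conditioning frames and receive timestep 0 below.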
+ + # expand the shape of timestep from [B] to [B, T] + if len(timestep.shape) == 1: + timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T] + timestep[:, :num_cond_latents] = 0 + + dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype) + timestep = timestep.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + + hidden_states = self.x_embedder(hidden_states) # [B, N, C] + + with amp.autocast(device_type='cuda', dtype=torch.float32): + t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] + + encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] + + if self.text_tokens_zero_pad and encoder_attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None] + encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype) + + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1) + encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C] + y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B] + else: + y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1]) + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w) + # hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw) + # hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C") + + # blocks + kv_cache_dict_ret = {} + for i, block in enumerate(self.blocks): + block_outputs = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=hidden_states, + y=encoder_hidden_states, + t=t, + y_seqlen=y_seqlens, + latent_shape=(N_t, N_h, N_w), + num_cond_latents=num_cond_latents, + return_kv=return_kv, + kv_cache=kv_cache_dict.get(i, None), + skip_crs_attn=skip_crs_attn, + ) + + if return_kv: + hidden_states, kv_cache = block_outputs + if offload_kv_cache: + kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu()) + else: + kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous()) + else: + hidden_states = block_outputs + + hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out] + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw) + + hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W] + + # cast to float32 for better accuracy + hidden_states = hidden_states.to(torch.float32) + + if return_kv: + return hidden_states, kv_cache_dict_ret + else: + return hidden_states + + + def unpatchify(self, x, N_t, N_h, N_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + 
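+    # Note on the KV-cache path above: a call with return_kv=True also returns a dict mapping
+    # block index -> (k, v); passing that dict back as kv_cache_dict on a later call routes each
+    # block through Attention.forward_with_kv_cache, which attends over the cached tokens
+    # concatenated with the new ones. offload_kv_cache=True keeps the cached tensors on CPU, and
+    # each block moves them back to the active device when they are used.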
@staticmethod + def state_dict_converter(): + return LongCatVideoTransformer3DModelDictConverter() + + +class LongCatVideoTransformer3DModelDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict + diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..16d72ddba8abb44dda24a03122d9fff79dd42660 --- /dev/null +++ b/diffsynth/models/model_loader.py @@ -0,0 +1,111 @@ +from ..core.loader import load_model, hash_model_file +from ..core.vram import AutoWrappedModule +from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS +import importlib, json, torch + + +class ModelPool: + def __init__(self): + self.model = [] + self.model_name = [] + self.model_path = [] + + def import_model_class(self, model_class): + split = model_class.rfind(".") + model_resource, model_class = model_class[:split], model_class[split+1:] + model_class = importlib.import_module(model_resource).__getattribute__(model_class) + return model_class + + def need_to_enable_vram_management(self, vram_config): + return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None + + def fetch_module_map(self, model_class, vram_config): + if self.need_to_enable_vram_management(vram_config): + if model_class in VRAM_MANAGEMENT_MODULE_MAPS: + module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()} + else: + module_map = {self.import_model_class(model_class): AutoWrappedModule} + else: + module_map = None + return module_map + + def load_model_file(self, config, path, vram_config, vram_limit=None): + model_class = self.import_model_class(config["model_class"]) + model_config = config.get("extra_kwargs", {}) + if "state_dict_converter" in config: + state_dict_converter = self.import_model_class(config["state_dict_converter"]) + else: + state_dict_converter = None + module_map = self.fetch_module_map(config["model_class"], vram_config) + model = load_model( + model_class, path, model_config, + vram_config["computation_dtype"], vram_config["computation_device"], + state_dict_converter, + use_disk_map=True, + vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, + ) + return model + + def default_vram_config(self): + vram_config = { + "offload_dtype": None, + "offload_device": None, + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cpu", + "computation_dtype": torch.bfloat16, + "computation_device": "cpu", + } + return vram_config + + def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False): + print(f"Loading models from: {json.dumps(path, indent=4)}") + if vram_config is None: + vram_config = self.default_vram_config() + model_hash = hash_model_file(path) + loaded = False + for config in MODEL_CONFIGS: + if config["model_hash"] == model_hash: + model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit) + if clear_parameters: self.clear_parameters(model) + self.model.append(model) + model_name = config["model_name"] + self.model_name.append(model_name) + self.model_path.append(path) + model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")} + print(f"Loaded model: {json.dumps(model_info, indent=4)}") + loaded = 
True + if not loaded: + raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}") + + def fetch_model(self, model_name, index=None): + fetched_models = [] + fetched_model_paths = [] + for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): + if model_name == model_name_: + fetched_models.append(model) + fetched_model_paths.append(model_path) + if len(fetched_models) == 0: + print(f"No {model_name} models available. This is not an error.") + model = None + elif len(fetched_models) == 1: + print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + model = fetched_models[0] + else: + if index is None: + model = fetched_models[0] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + elif isinstance(index, int): + model = fetched_models[:index] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.") + else: + model = fetched_models + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.") + return model + + def clear_parameters(self, model: torch.nn.Module): + for name, module in model.named_children(): + self.clear_parameters(module) + for name, param in model.named_parameters(recurse=False): + setattr(model, name, None) diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..011039842312b01a0bbb69b999bc868902736e9a --- /dev/null +++ b/diffsynth/models/nexus_gen.py @@ -0,0 +1,161 @@ +import torch +from PIL import Image + + +class NexusGenAutoregressiveModel(torch.nn.Module): + def __init__(self, max_length=1024, max_pixels=262640): + super(NexusGenAutoregressiveModel, self).__init__() + from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration + from transformers import Qwen2_5_VLConfig + self.max_length = max_length + self.max_pixels = max_pixels + model_config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 
151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLForConditionalGeneration(model_config) + self.processor = None + + + def load_processor(self, path): + from .nexus_gen_ar_model import Qwen2_5_VLProcessor + self.processor = Qwen2_5_VLProcessor.from_pretrained(path) + + + @staticmethod + def state_dict_converter(): + return NexusGenAutoregressiveModelStateDictConverter() + + def bound_image(self, image, max_pixels=262640): + from qwen_vl_utils import smart_resize + resized_height, resized_width = smart_resize( + image.height, + image.width, + max_pixels=max_pixels, + ) + return image.resize((resized_width, resized_height)) + + def get_editing_msg(self, instruction): + if '' not in instruction: + instruction = ' ' + instruction + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: "}] + return messages + + def get_generation_msg(self, instruction): + instruction = "Generate an image according to the following description: {}".format(instruction) + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] + return messages + + def forward(self, instruction, ref_image=None, num_img_tokens=81): + """ + Generate target embeddings for the given instruction and reference image. + """ + if ref_image is not None: + messages = self.get_editing_msg(instruction) + images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + else: + messages = self.get_generation_msg(instruction) + images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + + return output_image_embeddings + + def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81): + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + text = text.replace('', '<|vision_start|><|image_pad|><|vision_end|>') + inputs = processor( + text=[text], + images=images, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + + input_embeds = model.model.embed_tokens(inputs['input_ids']) + image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw']) + ground_truth_image_embeds = image_embeds[-num_img_tokens:] + input_image_embeds = image_embeds[:-num_img_tokens] + + image_mask = inputs['input_ids'] == model.config.image_token_id + indices = image_mask.cumsum(dim=1) + input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask) + gt_image_mask = torch.logical_and(image_mask, ~input_image_mask) + input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds) + input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds) + + image_prefill_embeds = model.image_prefill_embeds( + torch.arange(81, device=model.device).long() + ) + input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) + + position_ids, _ = model.get_rope_index( + inputs['input_ids'], + inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) + position_ids = position_ids.contiguous() + outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, 
attention_mask=inputs['attention_mask'], return_dict=True) + output_image_embeddings = outputs.image_embeddings[:, :-1, :] + output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]] + return output_image_embeddings, input_image_embeds, inputs['image_grid_thw'] + + +class NexusGenAutoregressiveModelStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict = {"model." + key: value for key, value in state_dict.items()} + return state_dict diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a29735e1ca71fcbf7847928a309e7799718ec7 --- /dev/null +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -0,0 +1,1143 @@ +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput +from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + QWEN2_5_VL_INPUTS_DOCSTRING, + ) + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
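Before continuing with the Qwen2.5-VL plumbing, here is a minimal, hypothetical sketch of how the NexusGenAutoregressiveModel defined above might be driven end to end. The processor path, image path, device, and dtype are assumptions, and in practice the randomly initialized weights would be replaced by loading a state dict (for example via the ModelPool.auto_load_model added earlier in this diff):

```python
import torch
from PIL import Image

from diffsynth.models.nexus_gen import NexusGenAutoregressiveModel

# Sketch only: weights are random here; real weights would come from a checkpoint.
model = NexusGenAutoregressiveModel()
model.load_processor("DiffSynth-Studio/Nexus-GenV2")         # assumed processor location
model = model.to(dtype=torch.bfloat16, device="cuda").eval()

ref_image = Image.open("reference.jpg").convert("RGB")       # placeholder image path
with torch.no_grad():
    target_embeds, input_embeds, image_grid_thw = model(
        "Turn the sky pink.", ref_image=ref_image, num_img_tokens=81
    )
```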
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + image_embeddings: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + self.image_prefill_embeds = nn.Embedding(81, config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. 
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**.
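For the pure-text case sketched above, the fallback branch further down in this method simply broadcasts one 1D index over the three rope axes; a rough standalone illustration with dummy token ids:

```python
import torch

input_ids = torch.tensor([[9906, 11, 1917, 0, 13]])   # (batch_size=1, seq_len=5), dummy ids
position_ids = (
    torch.arange(input_ids.shape[1], device=input_ids.device)
    .view(1, 1, -1)
    .expand(3, input_ids.shape[0], -1)                # (3, batch_size, seq_len)
)
print(position_ids[:, 0])                             # each of the 3 axes is [0, 1, 2, 3, 4]
```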
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = 
llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + token_loss_weight: Optional[float] = 0.1, + img_loss_weight: Optional[float] = 1.0, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
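As the loss computation further down in this forward pass does, positions holding image tokens are dropped from the cross-entropy target by overwriting their labels with -100; a tiny sketch with dummy ids:

```python
import torch

IMAGE_TOKEN_ID = 151655  # image_token_id from the config earlier in this diff

labels = torch.tensor([[9906, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 1917, -100]])
logits_labels = labels.clone()
logits_labels[labels == IMAGE_TOKEN_ID] = -100   # image positions do not contribute to the CE loss
print(logits_labels.tolist())                    # [[9906, -100, -100, 1917, -100]]
```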
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # test feature + inputs_embeds = self.model.embed_tokens(input_ids) + # for image encoding and training + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # position_ids [3, B, L] + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + image_embeds = self.vision_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + # prepare labels for logits + logits_labels = labels.clone().detach() + image_tokens = (labels == self.config.image_token_id) + logits_labels[image_tokens] = -100 + + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = logits_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) * token_loss_weight + + shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1) + shifted_image_embeds = image_embeds[:, :-1, :].contiguous() # (B, L-1, D) + masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d] # (num_image_tokens, D) + + mse_loss_fct = nn.MSELoss() + mse_loss_fct = mse_loss_fct.to(shift_logits.device) + if image_embeddings is None: + image_embeddings = torch.zeros_like(masked_image_embeds) + img_loss = mse_loss_fct(masked_image_embeds, image_embeddings) + + cos_sim = torch.cosine_similarity( + masked_image_embeds, + image_embeddings, + dim=-1 + ) + cos_loss = (1 - cos_sim).mean() + img_loss = 0.5 * img_loss + 0.5 * cos_loss + # fix nan for empty image tokens + if image_embeddings.size(0) == 0: + img_loss = img_loss.nan_to_num(0.0) + # combine the loss + loss = loss + img_loss_weight * img_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + image_embeddings=image_embeds, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
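The forward pass above pairs a down-weighted token loss with an image-embedding term built from equal parts MSE and (1 - cosine similarity); a standalone sketch of that term on dummy tensors:

```python
import torch
import torch.nn as nn

def image_embedding_loss(pred, target):
    # 0.5 * MSE + 0.5 * (1 - cosine similarity), as in the loss above.
    mse = nn.functional.mse_loss(pred, target)
    cos = torch.cosine_similarity(pred, target, dim=-1)
    loss = 0.5 * mse + 0.5 * (1.0 - cos).mean()
    # Mirror the nan guard used when a batch contains no image tokens.
    return loss.nan_to_num(0.0) if target.size(0) == 0 else loss

pred = torch.randn(81, 3584)     # predicted embeddings from vision_head (dummy)
target = torch.randn(81, 3584)   # reference visual features (dummy)
print(image_embedding_loss(pred, target).item())
```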
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + model_forward = self.__call__ + if isinstance(model_kwargs.get("past_key_values"), Cache): + is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache + is_compileable = is_compileable and not self.generation_config.disable_compile + if is_compileable and ( + self.device.type == "cuda" or generation_config.compile_config._compile_all_devices + ): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id + generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", self.get_default_image_grid_thw()) + num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) + output_image_embeddings = [] + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare prefilled embeds + model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs)) + + # parse position_ids from model_kwargs + model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # 
synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + # TODO: support batch image sampling + if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens: + output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1)) + + if synced_gpus and this_peer_finished: + continue + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # do not sample token + next_token_logits[:, self.config.vision_end_token_id] = -float('inf') + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id): + # probs[:, self.config.vision_end_token_id] = 0 + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + #TODO: support batch image sample + if num_img_tokens is not None: + cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1) + # check whether is sampling images + is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img) + is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens) + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to end sampling images + next_tokens[is_end_img] = self.config.vision_end_token_id + else: + # check whether to end sampling images + is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id)) + # replace the next token with the image token if is sampling image + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to start sampling images + is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id)) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & 
~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + # output the image embeddings + output_image_embeddings = torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None + + if return_dict_in_generate: + return GenerateDecoderOnlyAll2AllOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + output_image_embeddings=output_image_embeddings, + ) + else: + return input_ids + + + def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs): + if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img): + return {} + # TODO: support batch image sample + image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0) + inputs_embeds = self.image_prefill_embeds(image_idx) + return {"inputs_embeds": inputs_embeds} + + + def get_default_image_grid_thw(self,): + return torch.tensor([[1, 18, 18]]).to(self.device) + + + def get_num_image_tokens(self, image_grid_thw): + return int(torch.prod(image_grid_thw, dim=1).sum() // 4) + + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + num_img_tokens = model_kwargs.pop("generation_image_grid_thw", None) + super()._validate_model_kwargs(model_kwargs) + model_kwargs["generation_image_grid_thw"] = num_img_tokens + + def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs): + # Overwritten -- prepare position_ids for image tokens + cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)) + # TODO: support batch image sample + if cur_img_tokens > 0 and bool(is_sampling_img): + image_grid_thw = generation_image_grid_thw + if model_kwargs.get('image_grid_thw') is not None: + image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw]) + remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens + padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id) + padded_ids = torch.cat([input_ids, padding_ids], dim=1) + position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None) + if model_kwargs.get("use_cache", True): + position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1) + else: + position_ids = position_ids[:, :, :input_ids.shape[1]] + return {"position_ids": position_ids} + return {} + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + image_embeddings=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + 
position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = 
torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
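The generation helpers above and the processor below rely on the same arithmetic: one image contributes prod(t, h, w) // merge_size**2 positions (the hard-coded // 4 implies a merge size of 2). For the default [1, 18, 18] grid that gives 81, matching the 81-entry image_prefill_embeds table; a quick check:

```python
import torch

image_grid_thw = torch.tensor([[1, 18, 18]])   # default grid from get_default_image_grid_thw()
merge_size = 2
num_img_tokens = int(torch.prod(image_grid_thw, dim=1).sum()) // merge_size**2
print(num_img_tokens)                           # 81
```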
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def batch_decode_all2all(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + decoded = self.tokenizer.batch_decode(*args, **kwargs) + pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' + decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. 
+ + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..d69b3e1bfd50fc3b9c098f7775afae4020f9b320 --- /dev/null +++ b/diffsynth/models/nexus_gen_projector.py @@ -0,0 +1,417 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple + + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + from transformers.modeling_rope_utils import _compute_default_rope_parameters + self.rope_init_fn = _compute_default_rope_parameters + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + 
inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + 
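+        # Root-mean-square normalization: rescale by 1/sqrt(mean(x^2) + eps) in float32,
+        # apply the learned weight, then cast back to the input dtype.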
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NexusGenImageEmbeddingMerger(nn.Module): + def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + super().__init__() + from transformers import Qwen2_5_VLConfig + from transformers.activations import ACT2FN + config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.config = config + self.num_layers = num_layers + self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) + self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, out_channel * expand_ratio), + Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps), + ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel), + 
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps)) + self.base_grid = torch.tensor([[1, 72, 72]], device=device) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device) + + def get_position_ids(self, image_grid_thw): + """ + Generates position ids for the input embeddings grid. + modified from the qwen2_vl mrope. + """ + batch_size = image_grid_thw.shape[0] + spatial_merge_size = self.config.vision_config.spatial_merge_size + t, h, w = ( + image_grid_thw[0][0], + image_grid_thw[0][1], + image_grid_thw[0][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + scale_h = self.base_grid[0][1].item() / h.item() + scale_w = self.base_grid[0][2].item() / w.item() + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + time_tensor = expanded_range * self.config.vision_config.tokens_per_second + t_index = time_tensor.long().flatten().to(image_grid_thw.device) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w + # 3, B, L + position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2) + return position_ids + + def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None): + position_ids = self.get_position_ids(embeds_grid) + hidden_states = embeds + if ref_embeds is not None: + position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid) + position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1) + hidden_states = torch.cat((embeds, ref_embeds), dim=1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + + hidden_states = self.projector(hidden_states) + return hidden_states + + @staticmethod + def state_dict_converter(): + return NexusGenMergerStateDictConverter() + + +class NexusGenMergerStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')} + return merger_state_dict + + +class NexusGenAdapter(nn.Module): + """ + Adapter for Nexus-Gen generation decoder. 
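+
+    With the default arguments it maps 3584-dim Qwen2.5-VL hidden states to the 4096-dim
+    conditioning space of the generation decoder via a two-layer MLP
+    (Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm).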
+ """ + def __init__(self, input_dim=3584, output_dim=4096): + super(NexusGenAdapter, self).__init__() + self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim), + nn.LayerNorm(output_dim), nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.LayerNorm(output_dim)) + + def forward(self, x): + return self.adapter(x) + + @staticmethod + def state_dict_converter(): + return NexusGenAdapterStateDictConverter() + + +class NexusGenAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')} + return adapter_state_dict diff --git a/diffsynth/models/qwen_image_controlnet.py b/diffsynth/models/qwen_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce40809065b3eac020d2b1da29101681a44764a --- /dev/null +++ b/diffsynth/models/qwen_image_controlnet.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from .general_modules import RMSNorm + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072): + super().__init__() + self.x_rms = RMSNorm(dim, eps=1e-6) + self.y_rms = RMSNorm(dim, eps=1e-6) + self.input_proj = nn.Linear(dim, dim) + self.act = nn.GELU() + self.output_proj = nn.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + def init_weights(self): + # zero initialize output_proj + nn.init.zeros_(self.output_proj.weight) + nn.init.zeros_(self.output_proj.bias) + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + ): + super().__init__() + self.img_in = nn.Linear(in_dim + additional_in_dim, dim) + self.controlnet_blocks = nn.ModuleList( + [ + BlockWiseControlBlock(dim) + for _ in range(num_layers) + ] + ) + + def init_weight(self): + nn.init.zeros_(self.img_in.weight) + nn.init.zeros_(self.img_in.bias) + for block in self.controlnet_blocks: + block.init_weights() + + def process_controlnet_conditioning(self, controlnet_conditioning): + return self.img_in(controlnet_conditioning) + + def blockwise_forward(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..ac54945a94b1d13b29748585eb43db9420fa5765 --- /dev/null +++ b/diffsynth/models/qwen_image_dit.py @@ -0,0 +1,533 @@ +import torch, math +import torch.nn as nn +from typing import Tuple, Optional, Union, List +from einops import rearrange +from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + + +def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False): + if FLASH_ATTN_3_AVAILABLE and attention_mask is None: + if not enable_fp8_attention: + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + 
if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + origin_dtype = q.dtype + q_std, k_std, v_std = q.std(), k.std(), v.std() + q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn) + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1))) + if isinstance(x, tuple): + x = x[0] + x = x.to(origin_dtype) * v_std + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] +): + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + self.rope_cache = {} + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer( + index, + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + + def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): + if isinstance(video_fhw, list): + video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) + _, height, width = video_fhw + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + required_len = max_vid_index + max(txt_seq_lens) + cur_max_len = self.pos_freqs.shape[0] + if required_len <= cur_max_len: + return + + new_max_len = math.ceil(required_len / 512) * 512 + pos_index = torch.arange(new_max_len) + neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, 
self.axes_dim[2], self.theta), + ], dim=1) + return + + + def forward(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
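+        # Text tokens are positioned immediately after the largest image index, so text and
+        # image tokens occupy disjoint rotary position ranges.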
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + + def forward_sampling(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache: + frame_0, height_0, width_0 = video_fhw[0] + + rope_key_0 = f"0_{height_0}_{width_0}" + spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1) + h_indices = torch.linspace(0, height_0 - 1, height).long() + w_indices = torch.linspace(0, width_0 - 1, width).long() + h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij') + sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :] + + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame + + seq_lens = frame * height * width + self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone() + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone() + vid_freqs.append(self.rope_cache[rope_key].contiguous()) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
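+        # As in `forward`, text positions are offset past the largest image index so the two
+        # position ranges never overlap.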
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + +class QwenFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dropout: float = 0.0, + ): + super().__init__() + inner_dim = int(dim * 4) + self.net = nn.ModuleList([]) + self.net.append(ApproximateGELU(dim, inner_dim)) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class QwenDoubleStreamAttention(nn.Module): + def __init__( + self, + dim_a, + dim_b, + num_heads, + head_dim, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = nn.Linear(dim_a, dim_a) + self.to_k = nn.Linear(dim_a, dim_a) + self.to_v = nn.Linear(dim_a, dim_a) + self.norm_q = RMSNorm(head_dim, eps=1e-6) + self.norm_k = RMSNorm(head_dim, eps=1e-6) + + self.add_q_proj = nn.Linear(dim_b, dim_b) + self.add_k_proj = nn.Linear(dim_b, dim_b) + self.add_v_proj = nn.Linear(dim_b, dim_b) + self.norm_added_q = RMSNorm(head_dim, eps=1e-6) + self.norm_added_k = RMSNorm(head_dim, eps=1e-6) + + self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a)) + self.to_add_out = nn.Linear(dim_b, dim_b) + + def forward( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + enable_fp8_attention: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) + txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) + seq_txt = txt_q.shape[1] + + img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads) + img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads) + img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads) + + txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads) + txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads) + txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads) + + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=2) + joint_k = torch.cat([txt_k, img_k], dim=2) + joint_v = torch.cat([txt_v, img_v], dim=2) + + joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype) + + txt_attn_output = joint_attn_out[:, :seq_txt, :] + img_attn_output = joint_attn_out[:, seq_txt:, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + 
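+        # img_mod / txt_mod project the timestep embedding to 6 * dim values:
+        # (shift, scale, gate) for the attention branch and (shift, scale, gate) for the MLP branch.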
self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim), + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = QwenDoubleStreamAttention( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + def _modulate(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + enable_fp8_attention = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + + img_normed = self.img_norm1(image) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn) + + txt_normed = self.txt_norm1(text) + txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) + + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + ) + + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp) + + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + + img_mlp_out = self.img_mlp(img_modulated_2) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + + image = image + img_gate_2 * img_mlp_out + text = text + txt_gate_2 * txt_mlp_out + + return text, image + + +class QwenImageDiT(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + ): + super().__init__() + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + + self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True) + self.txt_norm = RMSNorm(3584, eps=1e-6) + + self.img_in = nn.Linear(64, 3072) + self.txt_in = nn.Linear(3584, 3072) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, single=True) + self.proj_out = nn.Linear(3072, 64) + + + def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes): + # prompt_emb + all_prompt_emb = entity_prompt_emb + [prompt_emb] + all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] + all_prompt_emb = torch.cat(all_prompt_emb, dim=1) + + # image_rotary_emb + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + entity_seq_lens 
= [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask] + entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] + txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + + # attention_mask + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype) + entity_masks = entity_masks + [global_mask] + + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + total_seq_len = sum(seq_lens) + image.shape[1] + patched_masks = [] + for i in range(N): + patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + patched_masks.append(patched_mask) + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + # prompt-image attention mask + image_start = sum(seq_lens) + image_end = total_seq_len + cumsum = [0] + single_image_seq = image_end - image_start + for length in seq_lens: + cumsum.append(cumsum[-1] + length) + for i in range(N): + prompt_start = cumsum[i] + prompt_end = cumsum[i+1] + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) + # repeat image mask to match the single image sequence length + repeat_time = single_image_seq // image_mask.shape[-1] + image_mask = image_mask.repeat(1, 1, repeat_time) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt attention mask, let the prompt tokens not attend to each other + for i in range(N): + for j in range(N): + if i == j: + continue + start_i, end_i = cumsum[i], cumsum[i+1] + start_j, end_j = cumsum[j], cumsum[j+1] + attention_mask[:, start_i:end_i, start_j:end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + + return all_prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + ): + img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + + image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image = self.img_in(image) + text = self.txt_in(self.txt_norm(prompt_emb)) + + conditioning = self.time_text_embed(timestep, image.dtype) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + + for block in self.transformer_blocks: + text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + ) + + image = self.norm_out(image, conditioning) + image = self.proj_out(image) + + latents = rearrange(image, "B (H W) (C P Q) -> B 
C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) + return image diff --git a/diffsynth/models/qwen_image_image2lora.py b/diffsynth/models/qwen_image_image2lora.py new file mode 100644 index 0000000000000000000000000000000000000000..6aefbf25de6ccdb37de2d2d44e644fb77952b570 --- /dev/null +++ b/diffsynth/models/qwen_image_image2lora.py @@ -0,0 +1,128 @@ +import torch + + +class CompressedMLP(torch.nn.Module): + def __init__(self, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias) + + def forward(self, x, residual=None): + x = self.proj_in(x) + if residual is not None: x = x + residual + x = self.proj_out(x) + return x + + +class ImageEmbeddingToLoraMatrix(torch.nn.Module): + def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank): + super().__init__() + self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank) + self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank) + self.lora_a_dim = lora_a_dim + self.lora_b_dim = lora_b_dim + self.rank = rank + + def forward(self, x, residual=None): + lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim) + lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank) + return lora_a, lora_b + + +class SequencialMLP(torch.nn.Module): + def __init__(self, length, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias) + self.length = length + self.in_dim = in_dim + self.mid_dim = mid_dim + + def forward(self, x): + x = x.view(self.length, self.in_dim) + x = self.proj_in(x) + x = x.view(1, self.length * self.mid_dim) + x = self.proj_out(x) + return x + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class QwenImageImage2LoRAModel(torch.nn.Module): + def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = [ + [ + ("attn.to_q", 3072, 3072), + ("attn.to_k", 3072, 3072), + ("attn.to_v", 3072, 3072), + ("attn.to_out.0", 3072, 3072), + ], + [ + ("img_mlp.net.2", 3072*4, 3072), + ("img_mod.1", 3072, 3072*6), + ], + [ + ("attn.add_q_proj", 3072, 3072), + ("attn.add_k_proj", 3072, 3072), + ("attn.add_v_proj", 3072, 3072), + ("attn.to_add_out", 3072, 3072), + 
], + [ + ("txt_mlp.net.2", 3072*4, 3072), + ("txt_mod.1", 3072, 3072*6), + ], + ] + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) diff --git a/diffsynth/models/qwen_image_text_encoder.py b/diffsynth/models/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f19d2d8ae2a61fd9cc45414a6bac18d28e4edcc9 --- /dev/null +++ b/diffsynth/models/qwen_image_text_encoder.py @@ -0,0 +1,190 @@ +import torch +from typing import Optional, Union + + +class QwenImageTextEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel + config = Qwen2_5_VLConfig(**{ + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "text_config": { + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": None, + "initializer_range": 0.02, + "intermediate_size": 18944, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl_text", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 
1000000.0, + "sliding_window": None, + "torch_dtype": "float32", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": None, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }, + "tie_word_embeddings": False, + "torch_dtype": "float32", + "transformers_version": "4.54.0", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "depth": 32, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 3420, + "model_type": "qwen2_5_vl", + "num_heads": 16, + "out_hidden_size": 3584, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "float32", + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLModel(config) + self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.config = config + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = False + output_hidden_states = True + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states diff --git a/diffsynth/models/qwen_image_vae.py b/diffsynth/models/qwen_image_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..cb047131058c703da701ca9417f71a3eba94d1ea --- /dev/null +++ b/diffsynth/models/qwen_image_vae.py @@ -0,0 +1,723 @@ +import torch +from typing import List, Optional, Tuple, Union +from torch import nn + + +CACHE_T = 2 + +class QwenImageCausalConv3d(torch.nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. 
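+    Causality is enforced by padding only the past side of the temporal axis; when `cache_x`
+    frames from a previous chunk are passed to `forward`, they are prepended to the input and
+    the temporal padding is reduced accordingly.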
+ + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = torch.nn.functional.pad(x, padding) + return super().forward(x) + + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
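+
+        When a `feat_cache` list is supplied to `forward`, the last CACHE_T frames of the input
+        to each causal convolution are cached so the next chunk can be processed without
+        recomputing earlier frames.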
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = torch.nn.SiLU() + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). 
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
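+        num_layers (int): Number of attention/residual-block pairs appended after the first
+            residual block.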
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = torch.nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = 
layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
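+        Note: `QwenImageVAE` constructs this decoder with `temperal_upsample` set to the reverse of the
+        encoder's `temperal_downsample` flags, so temporal resolution is restored in mirror order.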
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageVAE(torch.nn.Module): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, 
num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1) + self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1) + + def encode(self, x, **kwargs): + x = x.unsqueeze(2) + x = self.encoder(x) + x = self.quant_conv(x) + x = x[:, :16] + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = (x - mean) * std + x = x.squeeze(2) + return x + + def decode(self, x, **kwargs): + x = x.unsqueeze(2) + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = x / std + mean + x = self.post_quant_conv(x) + x = self.decoder(x) + x = x.squeeze(2) + return x diff --git a/diffsynth/models/sd_text_encoder.py b/diffsynth/models/sd_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a1171c265a43048c1623bd0c2375f4fc3f5e5d --- /dev/null +++ b/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,412 @@ +import torch +from .attention import Attention +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = 
self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) 
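+        # The result is an additive causal mask: 0 on and below the diagonal, -inf strictly above it.
+        # For length == 3 the mask is
+        #   [[0., -inf, -inf],
+        #    [0.,   0., -inf],
+        #    [0.,   0.,   0.]]
+        # so each token only attends to itself and earlier positions.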
+ return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + 
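+            # The remaining entries follow the same pattern for every layer index <i>:
+            #   ...encoder.layers.<i>.layer_norm{1,2}            -> encoders.<i>.layer_norm{1,2}
+            #   ...encoder.layers.<i>.mlp.fc{1,2}                -> encoders.<i>.fc{1,2}
+            #   ...encoder.layers.<i>.self_attn.{q,k,v,out}_proj -> encoders.<i>.attn.to_{q,k,v,out}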
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": 
"encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cd85adbb72542536de8f80043046997706e213e7 --- /dev/null +++ b/diffsynth/models/siglip2_image_encoder.py @@ -0,0 +1,70 @@ +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig +from transformers import SiglipImageProcessor +import torch + + +class Siglip2ImageEncoder(SiglipVisionTransformer): + def __init__(self): + config = SiglipVisionConfig( + attention_dropout = 0.0, + dtype = "float32", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1536, + image_size = 384, + intermediate_size = 6144, + layer_norm_eps = 1e-06, + model_type = "siglip_vision_model", + num_attention_heads = 16, + num_channels = 3, + 
num_hidden_layers = 40, + patch_size = 16, + transformers_version = "4.56.1", + _attn_implementation = "sdpa" + ) + super().__init__(config) + self.processor = SiglipImageProcessor( + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.5, + 0.5, + 0.5 + ], + image_processor_type = "SiglipImageProcessor", + image_std = [ + 0.5, + 0.5, + 0.5 + ], + processor_class = "SiglipProcessor", + resample = 2, + rescale_factor = 0.00392156862745098, + size = { + "height": 384, + "width": 384 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] + pixel_values = pixel_values.to(device=device, dtype=torch_dtype) + output_attentions = False + output_hidden_states = False + interpolate_pos_encoding = False + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return pooler_output \ No newline at end of file diff --git a/diffsynth/models/step1x_connector.py b/diffsynth/models/step1x_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..225c8fbcb54f8daebf48656141e9a8998d002fd8 --- /dev/null +++ b/diffsynth/models/step1x_connector.py @@ -0,0 +1,663 @@ +from typing import Optional + +import torch, math +import torch.nn +from einops import rearrange +from torch import nn +from functools import partial +from einops import rearrange + + + +def attention(q, k, v, attn_mask, mode="torch"): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "b n s d -> b s (n d)") + return x + + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, device=device, dtype=dtype) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. 
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True, **factory_kwargs + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore + nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(t.dtype) # type: ignore + t_emb = self.mlp(t_freq) + return t_emb + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. 
+ weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + +class IndividualTokenRefinerBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + + if self.need_CA: + self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + 
qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs,) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + if self.need_CA: + x = self.cross_attnblock(x, c, attn_mask, y) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + + + +class CrossAttnBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.norm1_2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_q = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + self.self_attn_kv = nn.Linear( + hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor=None, + + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + norm_y = self.norm1_2(y) + q = self.self_attn_q(norm_x) + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + kv = self.self_attn_kv(norm_y) + k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + return x + + + +class 
IndividualTokenRefiner(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=self.need_CA, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + y:torch.Tensor=None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + + for block in self.blocks: + x = block(x, c, self_attn_mask,y) + + return x + + +class SingleTokenRefiner(torch.nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.need_CA = need_CA + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 
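+        # Structure of this refiner: project the LLM hidden states to `hidden_size`,
+        # build a pooled conditioning vector c = t_embedder(t) + c_embedder(masked mean of x),
+        # then run `depth` IndividualTokenRefinerBlock layers (gated self-attention + MLP,
+        # plus cross-attention to `y` when need_CA is True). See forward() below.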
+ + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + if self.need_CA: + self.input_embedder_CA = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection( + in_channels, hidden_size, act_layer, **factory_kwargs + ) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=need_CA, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + y: torch.LongTensor=None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + if self.need_CA: + y = self.input_embedder_CA(y) + x = self.individual_token_refiner(x, c, mask, y) + else: + x = self.individual_token_refiner(x, c, mask) + + return x + + +class Qwen2Connector(torch.nn.Module): + def __init__( + self, + # biclip_dim=1024, + in_channels=3584, + hidden_size=4096, + heads_num=32, + depth=2, + need_CA=False, + device=None, + dtype=torch.bfloat16, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype":dtype} + + self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) + self.global_proj_out=nn.Linear(in_channels,768) + + self.scale_factor = nn.Parameter(torch.zeros(1)) + with torch.no_grad(): + self.scale_factor.data += -(1 - 0.09) + + def forward(self, x,t,mask): + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + x_mean = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device)) + + global_out=self.global_proj_out(x_mean) + encoder_hidden_states = self.S(x,t,mask) + return encoder_hidden_states,global_out diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d0fe22157e0938b0821c5a7dc5d0376d2bc35b6a --- /dev/null +++ b/diffsynth/models/step1x_text_encoder.py @@ -0,0 +1,194 @@ +import torch +from typing import Optional, Union +from .qwen_image_text_encoder import QwenImageTextEncoder + + +class Step1xEditEmbedder(torch.nn.Module): + def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"): + super().__init__() + self.max_length = max_length + self.dtype = dtype + self.device = device + + Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. 
+- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + self.prefix = Qwen25VL_7b_PREFIX + self.model = model + self.processor = processor + + def model_forward( + self, + model: QwenImageTextEncoder, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states + ) + + outputs = model.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states + + def forward(self, caption, ref_images): + text_list = caption + embs = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + masks = torch.zeros( + len(text_list), + self.max_length, + dtype=torch.long, + device=torch.cuda.current_device(), + ) + + def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): + + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": 
f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": imgs}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs = [imgs] + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + .unsqueeze(0) + .to("cuda") + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + outputs = self.model_forward( + self.model, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to("cuda"), + image_grid_thw=inputs.image_grid_thw.to("cuda"), + output_hidden_states=True, + ) + + emb = outputs[-1] + + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=torch.cuda.current_device(), + ) + + return embs, masks diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..3ace70d8b3162e77844f0957cd40207a54e674a9 --- /dev/null +++ b/diffsynth/models/wan_video_animate_adapter.py @@ -0,0 +1,650 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import Tuple, Optional, List +from einops import rearrange + + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, 
**kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
+ """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) H N D") + v = rearrange(v, "B L N H D -> (B L) H N D") + + q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp) + # Compute attention. 
+ attn = F.scaled_dot_product_attention(q, k, v) + + attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.kernel = torch.nn.Parameter(kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, 
stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, 
self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + motion = self.dec.direction(motion_feat) + return motion + + +class WanAnimateAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5) + self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4) + + def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + adapter_args = [x, motion_vec, motion_masks, False] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x diff --git a/diffsynth/models/wan_video_camera_controller.py b/diffsynth/models/wan_video_camera_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..45a44ee6bcd408d7ee9d18653f933151ce351a72 --- /dev/null +++ b/diffsynth/models/wan_video_camera_controller.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +import os +from typing_extensions import Literal + +class SimpleAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1): + super(SimpleAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] + ) + + def 
forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + def process_camera_coordinates( + self, + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"], + length: int, + height: int, + width: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + ): + if origin is None: + origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + coordinates = generate_camera_coordinates(direction, length, speed, origin) + plucker_embedding = process_pose_file(coordinates, width, height) + return plucker_embedding + + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def custom_meshgrid(*args): + # torch>=2.0.0 only + return torch.meshgrid(*args, indexing='ij') + + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, 
HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.linalg.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + + +def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + + + +def generate_camera_coordinates( + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"], + length: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) +): + coordinates = [list(origin)] + while len(coordinates) < length: + coor = coordinates[-1].copy() + if "Left" in direction: + coor[9] += speed + if "Right" in direction: + coor[9] -= speed + if "Up" in direction: + coor[13] += speed + if "Down" in direction: + coor[13] -= speed + if "In" in direction: + coor[18] -= speed + if "Out" in direction: + coor[18] += speed + coordinates.append(coor) + return coordinates diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..daafa7a6876a34c645a36b7d13ac582455eb6603 --- /dev/null +++ b/diffsynth/models/wan_video_dit.py @@ -0,0 +1,406 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Tuple, Optional +from einops import rearrange +from .wan_video_camera_controller import SimpleAdapter +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n 
s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x,tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = 
nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + 
nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): + x = self.patch_embedding(x) + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + x = [u + v for u, v in zip(x, y_camera)] + x = x[0].unsqueeze(0) + return x + + def 
unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x diff --git a/diffsynth/models/wan_video_dit_instance.py b/diffsynth/models/wan_video_dit_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..40a9568096d8f4e391ae0e48e66b306b5fca0310 --- /dev/null +++ b/diffsynth/models/wan_video_dit_instance.py @@ -0,0 +1,779 @@ +""" +Wan Video DiT with instance-aware control (T5 semantics + bbox/mask). + +This refactor keeps the original Wan DiT backbone while integrating: +- Instance tokens: ` is ` text (T5) + instance_id embedding. +- Mask-guided cross-attention: per-patch gating via bbox- or mask-projected hints. +- Backward compatibility: still accepts id-based class/state embeddings and pixel masks. 
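+
+Key tensor shapes used by the instance path below:
+- instance_ids: (B, N) integer ids embedded by `InstanceFeatureExtractor`.
+- state_text_embeds_multi: (B, N, S, text_dim) per-state text embeddings.
+- state_weights: (B, N, F, S) per-frame weights over the S states.
+- instance_masks: (B, N, L) patch-level gates consumed by `MaskGuidedCrossAttention`.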
+""" + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .wan_video_camera_controller import SimpleAdapter + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +# ----------------------------------------------------------------------------- +# Common utils +# ----------------------------------------------------------------------------- +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode: bool = False): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", 
n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +# ----------------------------------------------------------------------------- +# Core blocks +# ----------------------------------------------------------------------------- +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + + +class MaskGuidedCrossAttention(nn.Module): + """ + 每个 patch 只关注覆盖它的实例 token,使用 log-mask trick 保证 0 区域被强制屏蔽。 + """ + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(dim, dim, bias=False) + self.to_v = nn.Linear(dim, dim, bias=False) + + self.to_out = nn.Linear(dim, dim) + self.norm = nn.LayerNorm(dim, eps=eps) + self.gate = nn.Parameter(torch.zeros(1)) # zero-init for stability + + def _attend(self, x: torch.Tensor, 
instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor: + B, L, _ = x.shape + _, N, _ = instance_tokens.shape + if N == 0: + return x + if instance_masks.shape != (B, N, L): + raise ValueError(f"instance_masks shape mismatch, expect (B,N,L)=({B},{N},{L}), got {tuple(instance_masks.shape)}") + + h = self.num_heads + q = rearrange(self.to_q(self.norm(x)), "b l (h d) -> b h l d", h=h) + k = rearrange(self.to_k(instance_tokens), "b n (h d) -> b h n d", h=h) + v = rearrange(self.to_v(instance_tokens), "b n (h d) -> b h n d", h=h) + sim = torch.einsum("b h l d, b h n d -> b h l n", q, k) * self.scale + + mask_bias = instance_masks.transpose(1, 2).unsqueeze(1).to(dtype=sim.dtype) + sim = sim + torch.log(mask_bias.clamp(min=1e-6)) + attn = sim.softmax(dim=-1) + out = torch.einsum("b h l n, b h n d -> b h l d", attn, v) + out = rearrange(out, "b h l d -> b l (h d)") + return x + self.gate * self.to_out(out) + + def forward(self, x: torch.Tensor, instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor: + """ + instance_tokens supports: + - (B, N, D): static tokens for the whole sequence + - (B, T, N, D): tokens per patch-time (sequence assumed laid out as T contiguous chunks) + - (B, L, N, D): tokens per token position (used for sequence parallel chunking) + """ + B, L, _ = x.shape + if instance_tokens.ndim == 3: + return self._attend(x, instance_tokens, instance_masks) + + if instance_tokens.ndim != 4: + raise ValueError(f"instance_tokens must be 3D or 4D, got {tuple(instance_tokens.shape)}") + + if instance_tokens.shape[1] == L: + # per-token instance tokens: (B, L, N, D) + _, _, N, _ = instance_tokens.shape + if instance_masks.shape != (B, N, L): + raise ValueError(f"instance_masks shape mismatch, expect (B,N,L)=({B},{N},{L}), got {tuple(instance_masks.shape)}") + h = self.num_heads + q = rearrange(self.to_q(self.norm(x)), "b l (h d) -> b h l d", h=h) + k = rearrange(self.to_k(instance_tokens), "b l n (h d) -> b h l n d", h=h) + v = rearrange(self.to_v(instance_tokens), "b l n (h d) -> b h l n d", h=h) + sim = torch.einsum("b h l d, b h l n d -> b h l n", q, k) * self.scale + mask_bias = instance_masks.transpose(1, 2).unsqueeze(1).to(dtype=sim.dtype) + sim = sim + torch.log(mask_bias.clamp(min=1e-6)) + attn = sim.softmax(dim=-1) + out = torch.einsum("b h l n, b h l n d -> b h l d", attn, v) + out = rearrange(out, "b h l d -> b l (h d)") + return x + self.gate * self.to_out(out) + + # per-time instance tokens: (B, T, N, D) + _, T, _, _ = instance_tokens.shape + if L % T != 0: + raise ValueError(f"Token length L={L} must be divisible by T={T} for per-time instance tokens.") + hw = L // T + chunks = [] + for t in range(T): + s, e = t * hw, (t + 1) * hw + chunks.append(self._attend(x[:, s:e], instance_tokens[:, t], instance_masks[:, :, s:e])) + return torch.cat(chunks, dim=1) + + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input) + self.instance_cross_attn = MaskGuidedCrossAttention(dim, num_heads, eps) + + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + 
nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim), + ) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs, instance_tokens=None, instance_masks=None): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod + ).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + + # Self-attention with AdaLN modulation + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + + # Text / image cross-attention + x = x + self.cross_attn(self.norm3(x), context) + + # Instance-guided cross-attention + if instance_tokens is not None and instance_masks is not None: + x = self.instance_cross_attn(x, instance_tokens, instance_masks) + + # FFN with AdaLN modulation + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = ( + self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2) + ).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + scale) + shift) + return x + + +class InstanceFeatureExtractor(nn.Module): + """ + 将 `instance_id` 与 (class/state 组合短语) 的文本语义融合为实例 token,并支持按时间(帧/patch-time) + 的 state weights 做动态加权: + - 输入:`state_text_embeds_multi` 形状 (B, N, S, text_dim),其中每个 state 对应短语 `" is "` + - 输入:`state_weights` 形状 (B, N, F, S),F 为帧数(或任意时间长度),每帧对 S 个 state 的权重 + - 输出:实例 token 形状 (B, T, N, D),T 为时间长度(可选下采样到 patch-time) + """ + def __init__( + self, + num_instance_ids: int = 1000, + embedding_dim: int = 1280, + hidden_dim: int = 1280, + text_dim: int = 4096, + ): + super().__init__() + self.inst_id_emb = nn.Embedding(num_instance_ids, hidden_dim, padding_idx=0) + self.text_proj = nn.Sequential( + nn.Linear(int(text_dim), hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim, bias=False), + nn.LayerNorm(hidden_dim), + ) + + self.fusion = nn.Sequential( + nn.Linear(hidden_dim * 2, embedding_dim), + nn.SiLU(), + nn.Linear(embedding_dim, embedding_dim), + 
nn.LayerNorm(embedding_dim), + ) + + @staticmethod + def _pool_time_to_patches(weights: torch.Tensor, num_time_patches: int) -> torch.Tensor: + """ + Average-pool per-frame weights (B,N,F,S) to per-patch-time weights (B,N,T,S), + where mapping uses pt = floor(t * T / F). + """ + if weights.ndim != 4: + raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(weights.shape)}") + B, N, F, S = weights.shape + T = int(num_time_patches) + if T <= 0: + raise ValueError("num_time_patches must be > 0") + if F == T: + return weights + device = weights.device + idx = (torch.arange(F, device=device, dtype=torch.float32) * (T / max(float(F), 1.0))).floor().clamp(0, T - 1).long() + idx = idx.view(1, 1, F, 1).expand(B, N, F, S) + out = torch.zeros((B, N, T, S), device=device, dtype=weights.dtype) + out.scatter_add_(2, idx, weights) + cnt = torch.zeros((B, N, T, S), device=device, dtype=weights.dtype) + cnt.scatter_add_(2, idx, torch.ones_like(weights)) + return out / cnt.clamp(min=1.0) + + def forward( + self, + instance_ids: torch.Tensor, + state_text_embeds_multi: torch.Tensor, + state_weights: torch.Tensor, + num_time_patches: Optional[int] = None, + ): + if state_text_embeds_multi is None: + raise ValueError("state_text_embeds_multi is required.") + if state_weights is None: + raise ValueError("state_weights is required.") + if state_text_embeds_multi.ndim != 4: + raise ValueError(f"state_text_embeds_multi must be (B,N,S,D), got {tuple(state_text_embeds_multi.shape)}") + if state_weights.ndim != 4: + raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(state_weights.shape)}") + + B, N, S, _ = state_text_embeds_multi.shape + if instance_ids.shape[:2] != (B, N): + raise ValueError(f"instance_ids must be (B,N)=({B},{N}), got {tuple(instance_ids.shape)}") + if state_weights.shape[0] != B or state_weights.shape[1] != N or state_weights.shape[-1] != S: + raise ValueError(f"state_weights must be (B,N,F,S)=({B},{N},F,{S}), got {tuple(state_weights.shape)}") + + sem_multi = self.text_proj(state_text_embeds_multi) # (B,N,S,H) + weights = state_weights.to(dtype=sem_multi.dtype, device=sem_multi.device).clamp(min=0) + if num_time_patches is not None: + weights = self._pool_time_to_patches(weights, int(num_time_patches)) + # (B,N,T,S,H) -> (B,N,T,H) + sem_multi = sem_multi.unsqueeze(2) + weights = weights.unsqueeze(-1) + denom = weights.sum(dim=3).clamp(min=1e-6) + sem_time = (sem_multi * weights).sum(dim=3) / denom # (B,N,T,H) + + i_feat = self.inst_id_emb(instance_ids % self.inst_id_emb.num_embeddings).to(dtype=sem_time.dtype, device=sem_time.device) # (B,N,H) + i_time = i_feat.unsqueeze(2).expand(-1, -1, sem_time.shape[2], -1) + tokens = self.fusion(torch.cat([sem_time, i_time], dim=-1)) # (B,N,T,D) + return tokens.permute(0, 2, 1, 3).contiguous() # (B,T,N,D) + + +# ----------------------------------------------------------------------------- +# Main model +# ----------------------------------------------------------------------------- +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + # 
instance control + num_class_ids: int = 200, + num_state_ids: int = 100, + num_instance_ids: int = 1000, + state_feature_dim: int = 256, + instance_text_dim: Optional[int] = 4096, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate="tanh"), + nn.Linear(dim, dim), + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + ) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + self.blocks = nn.ModuleList([DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + instance_text_dim = int(text_dim) if instance_text_dim is None else int(instance_text_dim) + self.instance_encoder = InstanceFeatureExtractor( + num_instance_ids=num_instance_ids, + embedding_dim=dim, + hidden_dim=dim, + text_dim=instance_text_dim, + ) + self.instance_text_dim = instance_text_dim + + # ------------------------------ patch helpers ------------------------------ # + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): + """ + Returns: + tokens: (B, L, D) + grid_size: (F_p, H_p, W_p) + """ + x = self.patch_embedding(x) # (B, D, F_p, H_p, W_p) + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + if isinstance(y_camera, (list, tuple)): + x = x + y_camera[0] + else: + x = x + y_camera + grid_size = x.shape[2:] + x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() + return x, grid_size + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, "b (f h w) (x y z c) -> b c (f x) (h y) (w z)", + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2], + ) + + # ------------------------------ masks ------------------------------ # + def process_masks( + self, + grid_size, + image_size: Tuple[int, int, int], + bboxes: torch.Tensor, + bbox_mask: Optional[torch.Tensor] = None, + ): + """ + bbox-only path: + bboxes: (B, N, F, 4) or (B, N, 4), xyxy in pixel coords + bbox_mask: (B, N, F) or (B, N, 1), optional existence mask + Returns: + (B, N, L) flattened patch mask + """ + if bboxes is None: + raise ValueError("bboxes must be provided for instance control.") + return self._bboxes_to_masks(bboxes, bbox_mask, grid_size, image_size) + + def _bboxes_to_masks( + self, + bboxes: torch.Tensor, + bbox_mask: 
Optional[torch.Tensor], + grid_size: Tuple[int, int, int], + image_size: Tuple[int, int, int], + ): + f_p, h_p, w_p = grid_size + F_img, H_img, W_img = image_size + # Notes on coordinate space: + # - bboxes are interpreted in the same pixel space as (H_img, W_img) + # - projection to patch grid uses ratio (w_p / W_img) and (h_p / H_img) + # - time index is mapped by ratio (f_p / F_bbox) + + if bboxes.ndim == 3: # (B, N, 4) -> broadcast to frames + bboxes = bboxes.unsqueeze(2).expand(-1, -1, f_p, -1) + if bboxes.ndim != 4 or bboxes.shape[-1] != 4: + raise ValueError(f"bboxes must be (B,N,F,4) or (B,N,4); got {tuple(bboxes.shape)}") + + if bbox_mask is None: + bbox_mask = torch.ones(bboxes.shape[:3], device=bboxes.device, dtype=torch.bool) + else: + if bbox_mask.ndim == 3: + pass + elif bbox_mask.ndim == 2: + bbox_mask = bbox_mask.unsqueeze(-1).expand(-1, -1, bboxes.shape[2]) + else: + raise ValueError(f"bbox_mask must be (B,N,F) or (B,N,1); got {tuple(bbox_mask.shape)}") + bbox_mask = bbox_mask.to(dtype=torch.bool, device=bboxes.device) + + mask = bboxes.new_zeros((bboxes.shape[0], bboxes.shape[1], f_p, h_p, w_p), dtype=torch.float32) + f_bbox = int(bboxes.shape[2]) + w_scale = (w_p / max(float(W_img), 1.0)) + h_scale = (h_p / max(float(H_img), 1.0)) + + for b in range(bboxes.shape[0]): + for n in range(bboxes.shape[1]): + for t in range(f_bbox): + if not bbox_mask[b, n, t]: + continue + x0, y0, x1, y1 = bboxes[b, n, t] + x0 = max(0, min(float(x0), W_img)) + x1 = max(0, min(float(x1), W_img)) + y0 = max(0, min(float(y0), H_img)) + y1 = max(0, min(float(y1), H_img)) + if x1 <= x0 or y1 <= y0: + continue + + px0 = int(math.floor(x0 * w_scale)) + py0 = int(math.floor(y0 * h_scale)) + px1 = int(math.ceil(x1 * w_scale)) + py1 = int(math.ceil(y1 * h_scale)) + px0 = max(0, min(px0, w_p - 1)) + py0 = max(0, min(py0, h_p - 1)) + px1 = max(px0 + 1, min(px1, w_p)) + py1 = max(py0 + 1, min(py1, h_p)) + + pt = min(int(math.floor(t * f_p / max(f_bbox, 1))), f_p - 1) + mask[b, n, pt, py0:py1, px0:px1] = 1.0 + + mask_flat = rearrange(mask, "b n f h w -> b n (f h w)") + return mask_flat + + # ------------------------------ forward ------------------------------ # + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + # instance inputs (bbox-based) + instance_ids: Optional[torch.Tensor] = None, # (B, N) + instance_state_text_embeds_multi: Optional[torch.Tensor] = None, # (B, N, S, text_dim) + instance_state_weights: Optional[torch.Tensor] = None, # (B, N, F, S) weights per frame + instance_bboxes: Optional[torch.Tensor] = None, # (B, N, F, 4) + **kwargs, + ): + # Timestep embedding + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + + # Text embedding + context = self.text_embedding(context) + + # Image conditioning + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (B, Cx+Cy, F, H, W) + clip_embedding = self.img_emb(clip_feature) + context = torch.cat([clip_embedding, context], dim=1) + + orig_frames, orig_h, orig_w = x.shape[2:] + x, (f, h, w) = self.patchify(x) + grid_size = (f, h, w) + + # Instance control + inst_tokens = None + inst_mask_flat = None + has_instance = ( + instance_ids is not None + and instance_bboxes is not None + and instance_state_text_embeds_multi is not None + 
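# --- Worked example (not part of the patch) for the pixel -> patch-grid projection in
# --- _bboxes_to_masks above, with purely illustrative numbers. For a mask grid of
# --- (h_p, w_p) = (30, 52) over an image of (H_img, W_img) = (480, 832), i.e. 16 px per patch,
# --- the box (x0, y0, x1, y1) = (100, 60, 180, 140) maps to
# ---     px0 = floor(100 * 52 / 832) = 6     px1 = ceil(180 * 52 / 832) = 12
# ---     py0 = floor( 60 * 30 / 480) = 3     py1 = ceil(140 * 30 / 480) = 9
# --- so mask[b, n, pt, 3:9, 6:12] = 1, where pt = floor(t * f_p / F_bbox) selects the
# --- patch-time slot of frame t; floor/ceil keep partially covered patches inside the mask.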
and instance_state_weights is not None + and instance_ids.shape[1] > 0 + ) + if has_instance: + inst_tokens = self.instance_encoder( + instance_ids=instance_ids, + state_text_embeds_multi=instance_state_text_embeds_multi, + state_weights=instance_state_weights, + num_time_patches=f, + ) + + inst_mask_flat = self.process_masks( + grid_size, + image_size=(orig_frames, orig_h, orig_w), + bboxes=instance_bboxes, + ) + + # RoPE + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_with_instance(module): + def custom_forward(x, context, t_mod, freqs, instance_tokens, instance_masks): + return module(x, context, t_mod, freqs, instance_tokens=instance_tokens, instance_masks=instance_masks) + return custom_forward + + for block in self.blocks: + use_instance = inst_tokens is not None and inst_mask_flat is not None + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + if use_instance: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block), + x, context, t_mod, freqs, inst_tokens, inst_mask_flat, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if use_instance: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block), + x, context, t_mod, freqs, inst_tokens, inst_mask_flat, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if use_instance: + x = block(x, context, t_mod, freqs, instance_tokens=inst_tokens, instance_masks=inst_mask_flat) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x diff --git a/diffsynth/models/wan_video_dit_instancev.py b/diffsynth/models/wan_video_dit_instancev.py new file mode 100644 index 0000000000000000000000000000000000000000..99b1c74008695aa27ff75e74f2edf8101c08c230 --- /dev/null +++ b/diffsynth/models/wan_video_dit_instancev.py @@ -0,0 +1,693 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Tuple, Optional +from einops import rearrange +from .wan_video_camera_controller import SimpleAdapter +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s 
(n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x,tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + +def scaled_dot_product_attention_with_mask( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_heads: int, + attn_mask: Optional[torch.Tensor], +): + """Always uses PyTorch SDPA because FlashAttention variants may not support arbitrary masks. + + Args: + q,k,v: (B, S, D) + attn_mask: float mask broadcastable to (B, num_heads, Sq, Sk) with 0 for allowed, -inf for disallowed + or bool mask broadcastable where False indicates disallowed. + """ + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + if attn_mask is not None: + # Make it broadcastable to (B, n, Sq, Sk) + if attn_mask.dtype == torch.bool: + mask = attn_mask + else: + mask = attn_mask + if attn_mask.dim() == 3: + mask = mask.unsqueeze(1) + elif attn_mask.dim() != 4: + raise ValueError(f"attn_mask must be 3D or 4D, got shape={attn_mask.shape}") + x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + else: + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +class MaskedCrossAttention(nn.Module): + """Cross-attention with explicit attention mask support (used by IMCA).""" + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, attn_mask: Optional[torch.Tensor]): + q = self.norm_q(self.q(x_q)) + k = self.norm_k(self.k(x_kv)) + v = self.v(x_kv) + x = scaled_dot_product_attention_with_mask(q, k, v, num_heads=self.num_heads, attn_mask=attn_mask) + return self.o(x) + + +class SharedTimestepAdaptivePromptEnhancement(nn.Module): + """STAPE: I = I + (m_t + alpha1) * CrossAttn(I, T).""" + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.attn = CrossAttention(dim, num_heads, eps=eps, has_image_input=False) + # Dins-dimensional learnable residual gate (initialized to 0 for stability) + self.mt = nn.Parameter(torch.zeros(1, dim)) + + def forward(self, instance_tokens: torch.Tensor, caption_tokens: 
torch.Tensor, alpha1: torch.Tensor): + # instance_tokens: (B, F, Nins, D) caption_tokens: (B, Nctx, D) alpha1: (B, D) + B, F_, Nins, D = instance_tokens.shape + I = instance_tokens.reshape(B, F_ * Nins, D) + delta = self.attn(I, caption_tokens) # (B, F*Nins, D) + gate = (self.mt.to(dtype=I.dtype, device=I.device) + alpha1).unsqueeze(1) # (B, 1, D) + I = I + gate * delta + return I.reshape(B, F_, Nins, D) + + +class InstanceAwareMaskedCrossAttention(nn.Module): + """IMCA: masked cross-attention from visual tokens to instance prompt tokens.""" + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.attn = MaskedCrossAttention(dim, num_heads, eps=eps) + + def forward(self, visual_tokens: torch.Tensor, instance_tokens: torch.Tensor, attn_mask: torch.Tensor): + """Args: + visual_tokens: (B, F*HW, D) + instance_tokens: (B, F, Nins, D) + attn_mask: (B, F, Nins, HW) bool OR float, where True/1 means instance-token attends this visual token. + Returns: + (B, F*HW, D) + """ + B, Nv, D = visual_tokens.shape + _, F_, Nins, _ = instance_tokens.shape + HW = Nv // F_ + V = visual_tokens.reshape(B, F_, HW, D) + I = instance_tokens + # Convert mask to (B*F, HW, Nins) with 0 / -inf + M = attn_mask + if M.shape[-1] != HW: + raise ValueError(f"attn_mask last dim must be HW={HW}, got {M.shape[-1]}") + # (B,F,Nins,HW) -> (B,F,HW,Nins) + M = M.permute(0, 1, 3, 2).contiguous() + # 使用与 visual_tokens 相同的 dtype(通常是 bfloat16) + target_dtype = visual_tokens.dtype + if M.dtype == torch.bool: + sdpa_mask = torch.where(M, torch.zeros((), device=M.device, dtype=target_dtype), + torch.full((), float("-inf"), device=M.device, dtype=target_dtype)) + else: + # assume already 0/-inf or similar + sdpa_mask = M.to(dtype=target_dtype) + # Merge batch and frame + V_bf = V.reshape(B * F_, HW, D) + I_bf = I.reshape(B * F_, Nins, D) + sdpa_mask_bf = sdpa_mask.reshape(B * F_, HW, Nins) + out = self.attn(V_bf, I_bf, sdpa_mask_bf) + return out.reshape(B, F_ * HW, D) + + + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def 
norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + + +class DiTBlock(nn.Module): + def __init__( + self, + has_image_input: bool, + dim: int, + num_heads: int, + ffn_dim: int, + eps: float = 1e-6, + enable_instancev: bool = False, + stape: Optional[SharedTimestepAdaptivePromptEnhancement] = None, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + self.enable_instancev = enable_instancev + self.stape = stape + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input) + + # IMCA is inserted between self-attention and cross-attention as a residual branch + if enable_instancev: + self.imca = InstanceAwareMaskedCrossAttention(dim, num_heads, eps=eps) + # zero-initialized gated parameter m_v (paper Eq. 
4) + self.mv = nn.Parameter(torch.zeros(1)) + self.norm_imca = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + else: + self.imca = None + self.mv = None + self.norm_imca = None + + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim), + ) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + # Better initialization for IMCA: copy weights from the native cross-attention (paper discussion) + if enable_instancev and self.imca is not None: + self._init_imca_from_cross_attention() + + def _init_imca_from_cross_attention(self): + # copy q,k,v,o and norms + try: + self.imca.attn.q.load_state_dict(self.cross_attn.q.state_dict()) + self.imca.attn.k.load_state_dict(self.cross_attn.k.state_dict()) + self.imca.attn.v.load_state_dict(self.cross_attn.v.state_dict()) + self.imca.attn.o.load_state_dict(self.cross_attn.o.state_dict()) + self.imca.attn.norm_q.load_state_dict(self.cross_attn.norm_q.state_dict()) + self.imca.attn.norm_k.load_state_dict(self.cross_attn.norm_k.state_dict()) + except Exception: + # if anything mismatches, skip silently (keeps compatibility) + pass + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor, + t_mod: torch.Tensor, + freqs: torch.Tensor, + instance_tokens: Optional[torch.Tensor] = None, + instance_attn_mask: Optional[torch.Tensor] = None, + empty_instance_tokens: Optional[torch.Tensor] = None, + saug_drop_prob: float = 0.0, + ): + """Args: + x: (B, F*H*W, D) + context: global caption tokens T after embedding (B, Nctx, D) + instance_tokens: I (B, F, Nins, D) after embedding + instance_attn_mask: M (B, F, Nins, H*W) bool/float + empty_instance_tokens: used for SAUG unconditional branch (same shape as instance_tokens) + """ + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod + ).chunk(6, dim=chunk_dim) + + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + + # 1) Self-attention (paper Eq. 3) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + + # 2) IMCA (paper Eq. 4) + STAPE (paper Eq. 
6) + if self.enable_instancev and (self.imca is not None) and (instance_tokens is not None) and (instance_attn_mask is not None): + # SAUG training-time drop: keep spatial masks but empty the instance prompts with probability p + if isinstance(saug_drop_prob, torch.Tensor): + saug_p = float(saug_drop_prob.detach().cpu().item()) + else: + saug_p = float(saug_drop_prob) + + if self.training and saug_p > 0.0 and empty_instance_tokens is not None: + if torch.rand((), device=x.device) < saug_p: + instance_tokens_use = empty_instance_tokens + else: + instance_tokens_use = instance_tokens + else: + instance_tokens_use = instance_tokens + + # STAPE is shared across blocks (paper) + if self.stape is not None: + # reuse one AdaLN modulation vector as alpha1 (paper) + alpha1 = gate_msa # (B, D) + instance_tokens_use = self.stape(instance_tokens_use, context, alpha1=alpha1) + + # IMCA: masked cross-attn from visual tokens to instance tokens + imca_out = self.imca(self.norm_imca(x), instance_tokens_use, instance_attn_mask) + x = x + self.mv.to(dtype=x.dtype, device=x.device) * imca_out + + # 3) Native cross-attention with global caption tokens (paper Eq. 5) + x = x + self.cross_attn(self.norm3(x), context) + + # 4) FFN + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + enable_instancev: bool = False, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + 
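# --- Summary (not part of the patch) of the InstanceV block update implemented in
# --- DiTBlock.forward above; names follow the code and Eq. numbers follow its comments:
# ---     x  = x + gate_msa * SelfAttn(modulate(norm1(x), shift_msa, scale_msa))   # Eq. 3
# ---     I' = I + (m_t + gate_msa) * CrossAttn(I, T)                              # STAPE, Eq. 6
# ---     x  = x + m_v * IMCA(norm_imca(x), I', M)                                 # Eq. 4
# ---     x  = x + CrossAttn(norm3(x), T)                                          # Eq. 5
# ---     x  = x + gate_mlp * FFN(modulate(norm2(x), shift_mlp, scale_mlp))
# --- with m_t and m_v zero-initialized, so an untrained block reduces to the vanilla DiT block.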
self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + + + self.enable_instancev = enable_instancev + if enable_instancev: + # STAPE is shared across all DiT blocks (paper Section 4.2) + self.stape = SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads, eps=eps) + else: + self.stape = None + + self.blocks = nn.ModuleList([ + DiTBlock( + has_image_input=has_image_input, + dim=dim, + num_heads=num_heads, + ffn_dim=ffn_dim, + eps=eps, + enable_instancev=enable_instancev, + stape=self.stape, + ) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): + x = self.patch_embedding(x) + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + x = [u + v for u, v in zip(x, y_camera)] + x = x[0].unsqueeze(0) + return x + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + instance_prompt_tokens: Optional[torch.Tensor] = None, + instance_attn_mask: Optional[torch.Tensor] = None, + empty_instance_prompt_tokens: Optional[torch.Tensor] = None, + saug_drop_prob: float = 0.0, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + # Instance prompt tokens (paper Section 4.1): encode each instance prompt with the same text embedding layer + if instance_prompt_tokens is not None: + instance_tokens = self.text_embedding(instance_prompt_tokens) + else: + instance_tokens = None + + if empty_instance_prompt_tokens is not None: + empty_instance_tokens = self.text_embedding(empty_instance_prompt_tokens) + else: + empty_instance_tokens = None + + + # If SAUG unconditional tokens are not provided but InstanceV is enabled, default to zeros. + # (For best results, provide the pretrained tokens as described in the paper.) 
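# --- Example (not part of the patch): at sampling time SAUG runs two forward passes -- one
# --- with the real instance prompts and one with empty instance prompts but the *same*
# --- spatial masks -- and extrapolates between them via apply_saug (defined at the end of
# --- this file). Minimal numeric check with stand-in predictions:
import torch

eps_cond = torch.full((1, 4), 1.0)    # stand-in for the instance-conditioned prediction
eps_uncond = torch.full((1, 4), 0.5)  # stand-in for the empty-instance-prompt prediction
w = 2.0
eps = (1.0 + w) * eps_cond - w * eps_uncond   # same formula as apply_saug(eps_cond, eps_uncond, w)
print(eps)  # tensor([[2., 2., 2., 2.]]) -> pushed away from the instance-unconditional branch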
+ if self.enable_instancev and (instance_tokens is not None) and (empty_instance_tokens is None): + empty_instance_tokens = torch.zeros_like(instance_tokens) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + + for block in self.blocks: + use_instancev_inputs = (self.enable_instancev and (instance_tokens is not None) and (instance_attn_mask is not None)) + if self.training and use_gradient_checkpointing: + if use_instancev_inputs: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + instance_tokens, instance_attn_mask, empty_instance_tokens, + torch.tensor(float(saug_drop_prob), device=x.device, dtype=x.dtype), + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + instance_tokens, instance_attn_mask, empty_instance_tokens, + torch.tensor(float(saug_drop_prob), device=x.device, dtype=x.dtype), + use_reentrant=False, + ) + else: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if use_instancev_inputs: + x = block(x, context, t_mod, freqs, instance_tokens, instance_attn_mask, empty_instance_tokens, saug_drop_prob) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + +def apply_saug(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor: + """Spatially-Aware Unconditional Guidance (SAUG), paper Eq. (7): + eps_tilde = (1 + w) * eps_cond - w * eps_uncond + where eps_uncond is predicted with *empty instance prompts* but the same spatial masks. 
+ """ + return (1.0 + w) * eps_cond - w * eps_uncond + diff --git a/diffsynth/models/wan_video_dit_mvid.py b/diffsynth/models/wan_video_dit_mvid.py new file mode 100644 index 0000000000000000000000000000000000000000..e44715de2b02c3869970b0a1b504ff5bb1c5508f --- /dev/null +++ b/diffsynth/models/wan_video_dit_mvid.py @@ -0,0 +1,1046 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Tuple, Optional, List, Dict, Sequence +from einops import rearrange +from .utils import hash_state_dict_keys +from .wan_video_camera_controller import SimpleAdapter +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +print("FLASH_ATTN_3_AVAILABLE ",FLASH_ATTN_3_AVAILABLE) +print("FLASH_ATTN_2_AVAILABLE",FLASH_ATTN_2_AVAILABLE) +try: + from flash_attn_interface import flash_attn_varlen_func +except: + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + except Exception as e: + flash_attn_varlen_func = None + + +# def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): +# if compatibility_mode: +# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) +# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) +# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) +# x = F.scaled_dot_product_attention(q, k, v) +# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) +# elif FLASH_ATTN_3_AVAILABLE: +# q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) +# k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) +# v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) +# x = flash_attn_interface.flash_attn_func(q, k, v) +# if isinstance(x,tuple): +# x = x[0] +# x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) +# elif FLASH_ATTN_2_AVAILABLE: +# q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) +# k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) +# v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) +# x = flash_attn.flash_attn_func(q, k, v) +# x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) +# elif SAGE_ATTN_AVAILABLE: +# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) +# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) +# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) +# x = sageattn(q, k, v) +# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) +# else: +# q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) +# k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) +# v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) +# x = F.scaled_dot_product_attention(q, k, v) +# x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) +# return x + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attn_mask=None, shot_latent_indices=None): + + if attn_mask is not None: + + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + if shot_latent_indices is not None: + + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) 
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x,tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + # print("flas_attn_2") + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def build_global_reps_from_shots( + K_local_shots: List[torch.Tensor], + V_local_shots: List[torch.Tensor], + g_per: int, + mode: str = "firstk" # "mean" | "firstk" | "linspace" +): + """ + 简单的代表池构造:从每个 shot 的本地 K/V 生成若干代表 token,并拼成共享池。 + K_local_shots[i]: [Ni, H, D] + 返回: + K_global: [G_total, H, D], V_global: [G_total, H, D] + """ + reps_k, reps_v = [], [] + S = len(K_local_shots) + if S == 0: + return (torch.empty(0), torch.empty(0)) + + # g_per = max(1, G // S) if G > 0 else 0 + G = g_per * S + + for Ki, Vi in zip(K_local_shots, V_local_shots): + Ni = Ki.size(0) + if Ni == 0 or g_per == 0: + continue + if mode == "mean": + idx = torch.linspace(0, Ni - 1, steps=g_per, device=Ki.device).long() + reps_k.append(Ki.index_select(0, idx)) + reps_v.append(Vi.index_select(0, idx)) + elif mode == "firstk": + take = min(g_per, Ni) + reps_k.append(Ki[:take]) + reps_v.append(Vi[:take]) + elif mode == "linspace": + idx = torch.linspace(0, Ni - 1, steps=g_per, device=Ki.device).long() + reps_k.append(Ki.index_select(0, idx)) + reps_v.append(Vi.index_select(0, idx)) + else: + raise ValueError(f"unknown mode {mode}") + if len(reps_k) == 0: + return (torch.empty(0, *K_local_shots[0].shape[1:], device=K_local_shots[0].device, dtype=K_local_shots[0].dtype), + torch.empty(0, *V_local_shots[0].shape[1:], device=V_local_shots[0].device, dtype=V_local_shots[0].dtype)) + K_global = torch.cat(reps_k, dim=0) + V_global = torch.cat(reps_v, dim=0) + + return K_global, V_global + +def build_ID_reps( + shot_2_IDs: Dict[int, List[int]], + K_shots: List[torch.Tensor], # each: [Ni, H, D] + V_shots: List[torch.Tensor], # each: [Ni, H, D] +): + """ + shot_2_IDs: + { + shot_id: [id_shot_id_1, id_shot_id_2, ...] 
# ✅ 这里的 ID 是“特殊shot”的下标 + } + + Returns: + shot_id -> {"K": K_id, "V": V_id} + 其中 K_id/V_id 是该 shot 关联的所有 ID-shot 的 token 拼起来的结果: + K_id: [sum(N_id), H, D] + V_id: [sum(N_id), H, D] + """ + shot_id_kv = {} + + for shot_id, id_shot_ids in shot_2_IDs.items(): + reps_k, reps_v = [], [] + + for id_sid in id_shot_ids: + if id_sid < 0 or id_sid >= len(K_shots): + continue + + Ki = K_shots[id_sid] # [Ni, H, D] (ID-shot) + Vi = V_shots[id_sid] + + # Ki 可能为空(比如 padding / 没有token) + if Ki is None or Vi is None or Ki.numel() == 0: + continue + + reps_k.append(Ki) + reps_v.append(Vi) + + if len(reps_k) == 0: + # 没有任何 ID-shot 可用:返回空(保持 dtype/device 一致更安全) + # 如果你希望直接不返回该 shot,也可以改成 `continue` + device = K_shots[0].device + dtype = K_shots[0].dtype + shot_id_kv[shot_id] = { + "K": torch.empty(0, *K_shots[0].shape[1:], device=device, dtype=dtype), + "V": torch.empty(0, *V_shots[0].shape[1:], device=device, dtype=dtype), + } + continue + + shot_id_kv[shot_id] = { + "K": torch.cat(reps_k, dim=0), + "V": torch.cat(reps_v, dim=0), + } + + return shot_id_kv + +def attention_per_batch_with_shots( + q: torch.Tensor, # [b, s, n_heads*head_dim] + k: torch.Tensor, # [b, s, n_heads*head_dim] + v: torch.Tensor, # [b, s, n_heads*head_dim] + shot_latent_indices: Sequence[Sequence[int]], + num_heads: int, + # use_shared_global: bool = True, + per_g: int = 64, + # G_per_shot: int = 0, + dropout_p: float = 0.0, + causal: bool = False +): + + assert q.shape == k.shape == v.shape + b, s_tot, hd = q.shape + assert hd % num_heads == 0 + d = hd // num_heads + dtype = q.dtype + device = q.device + + + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads).contiguous() + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads).contiguous() + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads).contiguous() + + outputs = [] + + if flash_attn_varlen_func is None: + raise RuntimeError("flash_attn_varlen_func not available. 
Please install flash-attn v2+.") + + for bi in range(b): + + cuts = list(shot_latent_indices[bi]) + assert cuts[0] == 0 and cuts[-1] == s_tot, "shot_latent_indices must start with 0 and end with s_tot" + + Q_shots, K_shots, V_shots = [], [], [] + N_list = [] + for a, bnd in zip(cuts[:-1], cuts[1:]): + Q_shots.append(q[bi, :, a:bnd, :]) # [n, Ni, d] + K_shots.append(k[bi, :, a:bnd, :]) + V_shots.append(v[bi, :, a:bnd, :]) + N_list.append(bnd - a) + + + Q_locals = [rearrange(Qi, "n s d -> s n d") for Qi in Q_shots] + K_locals = [rearrange(Ki, "n s d -> s n d") for Ki in K_shots] + V_locals = [rearrange(Vi, "n s d -> s n d") for Vi in V_shots] + + ## 先不加镜头间的交互 + # K_global, V_global = build_global_reps_from_shots(K_locals, V_locals, per_g, mode="firstk") + + # K_list = [torch.cat([K_locals[i], K_global], dim=0) for i in range(len(K_locals))] + # V_list = [torch.cat([V_locals[i], V_global], dim=0) for i in range(len(V_locals))] + # kv_lengths = [Ni + K_global.size(0) for Ni in N_list] + + ### 把镜头和相关的人脸token 拼在一起, 现在还没有区分不同角度人脸对于镜头的重要性 + K_list = [] + V_list = [] + kv_lengths = [] + for i in range(len(K_locals)): + K_i = K_locals[i] # [Ni, H, D] + V_i = V_locals[i] + + if i in K_ID_per_shot: + K_id = K_ID_per_shot[i]["K"] # [Nid, H, D] + V_id = K_ID_per_shot[i]["V"] + K_cat = torch.cat([K_i, K_id], dim=0) + V_cat = torch.cat([V_i, V_id], dim=0) + kv_len = K_i.size(0) + K_id.size(0) + else: + # 这个 shot 没有任何 ID + K_cat = K_i + V_cat = V_i + kv_len = K_i.size(0) + + K_list.append(K_cat) + V_list.append(V_cat) + kv_lengths.append(kv_len) + + + Q_packed = torch.cat(Q_locals, dim=0) # [sum_N, n, d] + K_packed = torch.cat(K_list, dim=0) # [sum_(N+G), n, d] + V_packed = torch.cat(V_list, dim=0) # [sum_(N+G), n, d] + + Sshots = len(N_list) + q_seqlens = torch.tensor([0] + [sum(N_list[:i+1]) for i in range(Sshots)], + device=device, dtype=torch.int32) + kv_seqlens = torch.tensor([0] + [sum(kv_lengths[:i+1]) for i in range(Sshots)], + device=device, dtype=torch.int32) + max_q_seqlen = max(N_list) if len(N_list) > 0 else 0 + max_kv_seqlen = max(kv_lengths) if len(kv_lengths) > 0 else 0 + + + O_packed = flash_attn_varlen_func( + Q_packed, K_packed, V_packed, + q_seqlens, kv_seqlens, + max_q_seqlen, max_kv_seqlen, + softmax_scale=None, causal=causal + ) # [sum_N, n, d] + + + O_list = [] + for i in range(Sshots): + st = q_seqlens[i].item() + ed = q_seqlens[i+1].item() + Oi = O_packed[st:ed] # [Ni, n, d] + O_list.append(Oi) + O_local = torch.cat(O_list, dim=0) # [s_tot, n, d] + O_local = rearrange(O_local, "s n d -> n s d").contiguous() # [n, s, d] + outputs.append(O_local) + + + x = torch.stack(outputs, dim=0) # [b, n, s, d] + x = rearrange(x, "b n s d -> b s (n d)") + return x + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def precompute_freqs_cis_4d(dim: int, end: int = 1024, theta: float = 10000.0): + ### shot 的频率要不要和f h w 不一样???? 
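# --- Example (not part of the patch): the cu_seqlens bookkeeping that
# --- attention_per_batch_with_shots above builds for flash_attn_varlen_func, with
# --- illustrative numbers. Shots of 3, 4 and 5 query tokens, where shot 1 additionally
# --- attends to 2 appended ID-shot tokens in its K/V:
import torch

N_list = [3, 4, 5]            # query tokens per shot
kv_lengths = [3, 4 + 2, 5]    # K/V tokens per shot (shot 1 gets 2 extra ID tokens)
q_seqlens = torch.tensor([0] + [sum(N_list[:i + 1]) for i in range(len(N_list))], dtype=torch.int32)
kv_seqlens = torch.tensor([0] + [sum(kv_lengths[:i + 1]) for i in range(len(kv_lengths))], dtype=torch.int32)
print(q_seqlens.tolist())   # [0, 3, 7, 12] -> query segment i spans rows q_seqlens[i]:q_seqlens[i+1]
print(kv_seqlens.tolist())  # [0, 3, 9, 14] -> each shot only sees its own K/V segment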
+ + s_freqs_cis = precompute_freqs_cis(dim - 3 * (dim // 4), end, theta) + f_freqs_cis = precompute_freqs_cis(dim // 4, end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 4, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 4, end, theta) + return s_freqs_cis, f_freqs_cis, h_freqs_cis, w_freqs_cis + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v, attn_mask=None, shot_latent_indices = None, per_g=0): + if attn_mask is not None: + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attn_mask=attn_mask) + elif shot_latent_indices is not None: + x = attention_per_batch_with_shots(q=q, k=k, v=v, shot_latent_indices=shot_latent_indices, num_heads=self.num_heads,per_g=per_g) + else: + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = 
AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + 
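# --- Example (not part of the patch): the AdaLN modulation used by DiTBlock and Head in this
# --- file is an affine reparameterisation of the normalized activations,
# ---     modulate(x, shift, scale) = x * (1 + scale) + shift,
# --- with the (shift, scale, gate) vectors obtained by adding the learned `modulation` table
# --- to the per-timestep projection and chunking it:
import torch

dim = 8
modulation = torch.randn(1, 6, dim) / dim ** 0.5   # learned table, as in DiTBlock.__init__
t_mod = torch.randn(2, 6, dim)                     # time_projection output, (B, 6, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (modulation + t_mod).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([2, 1, 8]) -> broadcasts over the token dimension in modulate()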
return x + + +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + self.shot_freqs = precompute_freqs_cis_4d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): + x = self.patch_embedding(x) + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + x = [u + v for u, v in zip(x, y_camera)] + x = x[0].unsqueeze(0) + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, 
w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + @staticmethod + def state_dict_converter(): + return WanModelStateDictConverter() + + +class WanModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": 
"time_projection.1.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict_, config + + def from_civitai(self, state_dict): + state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} + if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 16, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": + # 1.3B PAI control + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": + # 14B PAI control + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } + elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": + config = { + "has_image_input": True, + 
"patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_image_pos_emb": True + } + elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504": + # 1.3B PAI control v1.1 + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6, + "has_ref_conv": True + } + elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b": + # 14B PAI control v1.1 + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": True + } + elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901": + # 1.3B PAI control-camera v1.1 + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 32, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "eps": 1e-6, + "has_ref_conv": False, + "add_control_adapter": True, + "in_dim_control_adapter": 24, + } + elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae": + # 14B PAI control-camera v1.1 + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 32, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": False, + "add_control_adapter": True, + "in_dim_control_adapter": 24, + } + elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + # Wan-AI/Wan2.2-TI2V-5B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + "seperated_timestep": True, + "require_clip_embedding": False, + "require_vae_embedding": False, + "fuse_vae_embedding_in_latents": True, + } + elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": + # Wan-AI/Wan2.2-I2V-A14B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "require_clip_embedding": False, + } + elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5": + # Wan2.2-Fun-A14B-Control + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 52, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": True, + "require_clip_embedding": False, + } + elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1": + # Wan2.2-Fun-A14B-Control-Camera + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": False, + "add_control_adapter": True, + "in_dim_control_adapter": 24, + "require_clip_embedding": False, + } + else: + config = {} + return 
state_dict, config diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbed8c05a2b5acca4c1ecaa2a68100433c74366 --- /dev/null +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -0,0 +1,594 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple +from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d + + +def torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1 + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, 
mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class FramePackMotioner(nn.Module): + + def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, 
-overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2 + ) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None + ) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +class AdaLayerNorm(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_dim: int, + norm_eps: float = 1e-5, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False) + + def forward(self, x, temb): + temb = self.linear(F.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + enable_adain=False, + adain_dim=2048, + ): + super().__init__() + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in 
zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([CrossAttention( + dim=dim, + num_heads=num_heads, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)]) + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights.to(device=features.device, dtype=features.dtype)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class WanS2VDiTBlock(DiTBlock): + + def forward(self, x, context, t_mod, seq_len_x, freqs): + t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
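+        # Index 0 of each modulation entry modulates the first seq_len_x tokens
+        # (the noisy video latents); index 1 modulates the remaining tokens (the
+        # reference / motion tokens appended after them). The six chunks are
+        # therefore expanded per group and concatenated along the sequence
+        # dimension so that every token receives its own shift/scale/gate.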
+ t_mod = [ + torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + for element in t_mod + ] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class WanS2VModel(torch.nn.Module): + + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + cond_dim: int, + audio_dim: int, + num_audio_token: int, + enable_adain: bool = True, + audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + zero_timestep: bool = True, + add_last_motion: bool = True, + framepack_drop_mode: str = "padd", + fuse_vae_embedding_in_latents: bool = True, + require_vae_embedding: bool = False, + seperated_timestep: bool = False, + require_clip_embedding: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.patch_size = patch_size + self.num_heads = num_heads + self.enbale_adain = enable_adain + self.add_last_motion = add_last_motion + self.zero_timestep = zero_timestep + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + self.require_vae_embedding = require_vae_embedding + self.seperated_timestep = seperated_timestep + self.require_clip_embedding = require_clip_embedding + + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = Head(dim, out_dim, patch_size, eps) + self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1) + + self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) + self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=dim, + num_heads=num_heads, + inject_layer=audio_inject_layers, + enable_adain=enable_adain, + adain_dim=dim, + ) + self.trainable_cond_mask = nn.Embedding(3, dim) + self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode) + + def patchify(self, x: torch.Tensor): + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, + 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], + h=grid_size[1], + w=grid_size[2], + x=self.patch_size[0], + y=self.patch_size[1], + z=self.patch_size[2] + ) + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, 
add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): + # inject the motion frames token to the hidden states + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) + if len(mot) > 0: + x = torch.cat([x, mot[0]], dim=1) + rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1) + mask_input = torch.cat( + [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1 + ) + return x, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + num_frames = audio_emb.shape[1] + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sp_group + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + return hidden_states + + def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input) + audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() + merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + return audio_emb_global, merged_audio_emb + + def get_grid_sizes(self, grid_size_x, grid_size_ref): + f, h, w = grid_size_x + rf, rh, rw = grid_size_ref + grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) + grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]] + grid_sizes_ref = [[ + torch.tensor([30, 0, 0]).unsqueeze(0), + torch.tensor([31, rh, rw]).unsqueeze(0), + torch.tensor([1, rh, rw]).unsqueeze(0), + ]] + return grid_sizes_x + grid_sizes_ref + + def forward( + self, + latents, + timestep, + context, + audio_input, + motion_latents, + pose_cond, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False + ): + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = 
self.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) + + # x and pose_cond + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = x.shape[1] + + # reference image + ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute( + x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + ) + # motion + x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + + x = x + self.trainable_cond_mask(mask).to(x.dtype) + + # t_mod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(self.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + + x = x[:, :seq_len_x] + x = self.head(x, t[:-1]) + x = self.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/diffsynth/models/wan_video_dit_slots.py b/diffsynth/models/wan_video_dit_slots.py new file mode 100644 index 0000000000000000000000000000000000000000..55e1f94c8caa17c874239d27661bc0d791952f5c --- /dev/null +++ b/diffsynth/models/wan_video_dit_slots.py @@ -0,0 +1,870 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Tuple, Optional, Dict, Any +from einops import rearrange +from .wan_video_camera_controller import SimpleAdapter + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except 
ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_heads: int, + compatibility_mode: bool = False +): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer( + position.type(torch.float64), + torch.pow( + 10000, + -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2), + ), + ) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex( + x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) + ) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() 
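+        # flash_attention() above dispatches at call time to FlashAttention-3,
+        # FlashAttention-2, SageAttention, or the PyTorch scaled_dot_product_attention
+        # fallback, depending on which backend imported successfully (or when
+        # compatibility_mode is set); this thin wrapper only needs to keep the
+        # head count used to reshape q/k/v.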
+ self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + """原有 SelfAttention:带 RoPE,给 video patch tokens 用""" + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class SelfAttentionNoRoPE(nn.Module): + """给 slots 用的 self-attn:不加 RoPE(slot 没有稳定网格位置信息时更稳)""" + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + """ + x: queries + y: keys/values (context) + """ + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x_out = self.attn(q, k, v) + + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y_img = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x_out = x_out + y_img + + return self.o(x_out) + + +class GateModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + + +# -------------------------- +# ROI gather/scatter helpers +# -------------------------- +def gather_tokens(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + x: (b, s, c) + idx: (b, m) long + return: (b, m, c) + """ + b, s, c = x.shape + return x.gather(1, idx[..., None].expand(b, idx.shape[1], c)) + + +def scatter_tokens(x: torch.Tensor, idx: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + x: (b, s, c) + idx: (b, m) long + y: (b, m, c) + return: (b, s, c) with y written back to idx positions + """ + b, s, c = x.shape + out = x.clone() + out.scatter_(1, idx[..., None].expand(b, idx.shape[1], c), y) + return out + + +def bbox_to_mask(bbox_xyxy: 
torch.Tensor, H: int, W: int) -> torch.Tensor:
+    """
+    bbox_xyxy: (b, 4) float/int in [0,W) [0,H)
+    return: (b, 1, H, W) float {0,1}
+    """
+    b = bbox_xyxy.shape[0]
+    mask = torch.zeros((b, 1, H, W), device=bbox_xyxy.device, dtype=torch.float32)
+    x1, y1, x2, y2 = bbox_xyxy[:, 0], bbox_xyxy[:, 1], bbox_xyxy[:, 2], bbox_xyxy[:, 3]
+    x1 = x1.clamp(0, W - 1).long()
+    x2 = x2.clamp(0, W).long()
+    y1 = y1.clamp(0, H - 1).long()
+    y2 = y2.clamp(0, H).long()
+    for i in range(b):
+        mask[i, 0, y1[i]:y2[i], x1[i]:x2[i]] = 1.0
+    return mask
+
+
+@torch.no_grad()
+def mask_to_roi_idx(
+    mask: torch.Tensor,
+    grid_fhw: Tuple[int, int, int],
+    *,
+    frame_index: int = 0,
+    roi_token_budget: int = 256,
+    mode: str = "topk",
+) -> torch.Tensor:
+    """
+    Downsample a single-frame mask to the patch grid and return a fixed-length
+    roi_idx (used by gather_tokens / scatter_tokens).
+
+    mask: (b, H, W) or (b, 1, H, W) -- at inference only one frame's mask is
+        given, so pass that frame's mask
+    grid_fhw: (f, h, w) from patchify
+    frame_index: the frame on which this interaction happens (default 0)
+    roi_token_budget: fixed m, avoids variable-length inputs that flash-attn
+        cannot handle
+    mode: "topk" or "random"
+    """
+    if mask.dim() == 3:
+        mask = mask[:, None]
+    b, _, H, W = mask.shape
+    f, h, w = grid_fhw
+    assert 0 <= frame_index < f, f"frame_index {frame_index} out of range f={f}"
+
+    # downsample to the patch grid (h, w)
+    m_small = F.interpolate(mask.float(), size=(h, w), mode="bilinear", align_corners=False)  # (b,1,h,w)
+    m_small = m_small[:, 0]  # (b,h,w)
+    flat = m_small.reshape(b, h * w)  # (b, h*w)
+
+    # convert to global token indices: tokens are flattened in (f, h, w) order
+    base = frame_index * (h * w)
+
+    if mode == "topk":
+        scores = flat
+        k = min(roi_token_budget, h * w)
+        topv, topi = torch.topk(scores, k=k, dim=1)
+        # if k < budget, pad up to the budget
+        if k < roi_token_budget:
+            pad = topi[:, :1].expand(b, roi_token_budget - k)
+            topi = torch.cat([topi, pad], dim=1)
+        idx = topi[:, :roi_token_budget] + base
+        return idx.long()
+
+    if mode == "random":
+        # sample randomly from non-zero positions; repeat to pad if there are too few
+        idx_list = []
+        for bi in range(b):
+            nz = torch.nonzero(flat[bi] > 0.01, as_tuple=False).flatten()
+            if nz.numel() == 0:
+                # all-zero mask: fall back to random positions
+                nz = torch.arange(h * w, device=mask.device)
+            if nz.numel() >= roi_token_budget:
+                sel = nz[torch.randperm(nz.numel(), device=mask.device)[:roi_token_budget]]
+            else:
+                rep = nz[torch.randint(0, nz.numel(), (roi_token_budget,), device=mask.device)]
+                sel = rep
+            idx_list.append(sel)
+        idx = torch.stack(idx_list, dim=0) + base
+        return idx.long()
+
+    raise ValueError(f"Unknown mode={mode}")
+
+
+# --------------------------
+# Slots-enabled DiT block
+# --------------------------
+class DiTBlockWithSlots(nn.Module):
+    """
+    Extends the original DiTBlock with:
+    - text -> slots
+    - roi_patches -> slots
+    - slots self-attn (instance-to-instance interaction)
+    - slots -> roi_patches (write-back)
+    The original global patch <- text cross-attention is disabled by default
+    (to avoid polluting the whole frame).
+    """
+    def __init__(
+        self,
+        has_image_input: bool,
+        dim: int,
+        num_heads: int,
+        ffn_dim: int,
+        eps: float = 1e-6,
+        enable_patch_text_cross_attn: bool = False,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.ffn_dim = ffn_dim
+        self.enable_patch_text_cross_attn = enable_patch_text_cross_attn
+
+        # patch path (keeps the original logic)
+        self.self_attn = SelfAttention(dim, num_heads, eps)
+        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.ffn = nn.Sequential(
+            nn.Linear(dim, ffn_dim),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(ffn_dim, dim),
+        )
+        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+        self.gate = GateModule()
+
+        # (optional) patch <- text (the original cross-attention)
+        if enable_patch_text_cross_attn:
+            self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
+            self.norm3 = nn.LayerNorm(dim, eps=eps)
+
+        # slot modules
+        self.slot_norm = nn.LayerNorm(dim, eps=eps)
+        self.slot_text_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)  # slots <- context
+        self.slot_from_patch = CrossAttention(dim, num_heads, eps, has_image_input=False)  # slots <- roi_patches
+        self.slot_self = SelfAttentionNoRoPE(dim, num_heads, eps)  # slots <-> slots
+        self.patch_from_slot = CrossAttention(dim, num_heads, eps, has_image_input=False)  # roi_patches <- slots
+
+    def forward(
+        self,
+        x: torch.Tensor,            # (b, s, dim) patch tokens
+        slots: torch.Tensor,        # (b, n_slots, dim)
+        context: torch.Tensor,      # (b, n_ctx, dim)
+        t_mod: torch.Tensor,        # (b, 6, dim)
+        freqs: torch.Tensor,        # (s, 1, rope_dim) complex
+        roi_idx: Optional[torch.Tensor] = None,  # (b, m)
+    ):
+        # ---- original patch self-attn + gated adaLN ----
+        has_seq = len(t_mod.shape) == 4
+        chunk_dim = 2 if has_seq else 1
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
+        ).chunk(6, dim=chunk_dim)
+
+        if has_seq:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                shift_msa.squeeze(2),
+                scale_msa.squeeze(2),
+                gate_msa.squeeze(2),
+                shift_mlp.squeeze(2),
+                scale_mlp.squeeze(2),
+                gate_mlp.squeeze(2),
+            )
+
+        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
+
+        # ---- ROI select (spatially local) ----
+        if roi_idx is None:
+            x_roi = x
+        else:
+            x_roi = gather_tokens(x, roi_idx)
+
+        # ---- text -> slots (instances are conditioned on the text/instruction) ----
+        slots = slots + self.slot_text_attn(self.slot_norm(slots), context)
+
+        # ---- video(ROI) -> slots (instances read state/appearance from the local video region) ----
+        slots = slots + self.slot_from_patch(self.slot_norm(slots), x_roi)
+
+        # ---- slots self-attn (interaction between instances) ----
+        slots = slots + self.slot_self(self.slot_norm(slots))
+
+        # ---- slots -> video(ROI) (write the interaction/state back into the local video region) ----
+        x_roi = x_roi + self.patch_from_slot(self.slot_norm(x_roi), slots)
+
+        if roi_idx is None:
+            x = x_roi
+        else:
+            x = scatter_tokens(x, roi_idx, x_roi)
+
+        # ---- (optional) patch <- text (not recommended to enable by default) ----
+        if self.enable_patch_text_cross_attn:
+            x = x + self.cross_attn(self.norm3(x), context)
+
+        # ---- original FFN ----
+        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        x = self.gate(x, gate_mlp, self.ffn(input_x))
+
+        return x, slots
+
+
+class MLP(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, has_pos_emb=False):
+        super().__init__()
+        self.proj = torch.nn.Sequential(
+            nn.LayerNorm(in_dim),
+            nn.Linear(in_dim, in_dim),
+            nn.GELU(),
+            nn.Linear(in_dim, out_dim),
+            nn.LayerNorm(out_dim),
+        )
+        self.has_pos_emb = has_pos_emb
+        if has_pos_emb:
+            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
+
+    def forward(self, x):
+        if self.has_pos_emb:
+            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
+        return self.proj(x)
+
+
+class Head(nn.Module):
+    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
+        super().__init__()
+        self.dim = dim
+        self.patch_size = patch_size
+        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
+        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+    def forward(self, x, t_mod):
+        if len(t_mod.shape) == 3:
+            shift, scale = (
+                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)
+            ).chunk(2, dim=2)
+            x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
+        else:
+            shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
+            x = self.head(self.norm(x) * (1 + scale) + shift)
+        return x
+
+
+class WanModel(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        in_dim: int,
+        ffn_dim: int,
+        out_dim: int,
+        text_dim: int,
+        freq_dim: int,
+        eps: float,
+        patch_size: Tuple[int, int, int],
+        num_heads: int,
+        num_layers: int,
+        has_image_input: bool,
+        has_image_pos_emb: bool = False,
+        has_ref_conv: bool = False,
+        add_control_adapter: bool = False,
+        in_dim_control_adapter: int = 24,
+        seperated_timestep: bool = False,
+        require_vae_embedding: bool = True,
+        require_clip_embedding: bool = True,
+        fuse_vae_embedding_in_latents: bool = False,
+
+        # -------- slots args (new) --------
+        enable_slots: bool = True,
+        num_slots: int = 16,
+        instance_state_dim: int = 0,  # dimension of the InstanceCap state vector
+        state_head_dim: int = 0,  # if > 0, predict slots -> state_pred for supervision
+        enable_patch_text_cross_attn: bool = False,  # not recommended to enable by default
+    ):
+        super().__init__()
+        self.dim = dim
+        self.in_dim = in_dim
+        self.freq_dim = freq_dim
+        self.has_image_input = has_image_input
+        self.patch_size = patch_size
+        self.seperated_timestep = seperated_timestep
+        self.require_vae_embedding = require_vae_embedding
+        self.require_clip_embedding = require_clip_embedding
+        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
+
+        self.enable_slots = enable_slots
+        self.num_slots = num_slots
+        self.instance_state_dim = instance_state_dim
+        self.state_head_dim = state_head_dim
+
+        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
+
+        self.text_embedding = nn.Sequential(
+            nn.Linear(text_dim, dim),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim, dim),
+        )
+        self.time_embedding = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+        )
+        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+        # blocks
+        if enable_slots:
+            self.blocks = nn.ModuleList([
+                DiTBlockWithSlots(
+                    has_image_input=has_image_input,
+                    dim=dim,
+                    num_heads=num_heads,
+                    ffn_dim=ffn_dim,
+                    eps=eps,
+                    enable_patch_text_cross_attn=enable_patch_text_cross_attn,
+                )
+                for _ in range(num_layers)
+            ])
+        else:
+            # fall back to the original DiTBlock (if needed)
+            self.blocks = nn.ModuleList([
+                DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
+                for _ in range(num_layers)
+            ])
+
+        self.head = Head(dim, out_dim, patch_size, eps)
+        head_dim = dim // num_heads
+        self.freqs = precompute_freqs_cis_3d(head_dim)
+
+        if has_image_input:
+            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
+        if has_ref_conv:
+            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
+        self.has_image_pos_emb = has_image_pos_emb
+        self.has_ref_conv = has_ref_conv
+
+        if add_control_adapter:
+            self.control_adapter = SimpleAdapter(
+                in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]
+            )
+        else:
+            self.control_adapter = None
+
+        # ---- slots params (new) ----
+        if enable_slots:
+            self.slot_base = nn.Parameter(torch.randn(1, num_slots, dim) / dim**0.5)
+
+            self.instance_proj = None
+            if instance_state_dim > 0:
+                self.instance_proj = nn.Sequential(
+                    nn.LayerNorm(instance_state_dim),
+                    nn.Linear(instance_state_dim, dim),
+                    nn.GELU(),
+                    nn.Linear(dim, dim),
+                )
+
+            self.state_head = None
+            if state_head_dim > 0:
+                self.state_head = nn.Sequential(
+                    nn.LayerNorm(dim),
+                    nn.Linear(dim, dim),
+                    nn.GELU(),
+                    nn.Linear(dim, state_head_dim),
+                )
+
+    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
+        """
+        return:
+            tokens: (b, f*h*w, dim)
+            grid: (f, h, w)
+        """
+        x = self.patch_embedding(x)  # (b, dim, f, h, w)
+
+        if self.control_adapter is not None and control_camera_latents_input is not None:
+            y_camera = self.control_adapter(control_camera_latents_input)
+            # handle y_camera being either a list/tuple or a tensor
+            if isinstance(y_camera, (list, tuple)):
+                # if the adapter returns a list of (b, dim, f, h, w) tensors
+                x = [u + v for u, v in zip(x, y_camera)]
+                x = x[0].unsqueeze(0)
+            else:
+                x = x + y_camera
+
+        f, h, w = x.shape[2], x.shape[3], x.shape[4]
+        x = rearrange(x, "b c f h w -> b (f h w) c")
+        return x, (f, h, w)
+
+    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
+        return rearrange(
+            x,
+            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
+            f=grid_size[0],
+            h=grid_size[1],
+            w=grid_size[2],
+            x=self.patch_size[0],
+            y=self.patch_size[1],
+            z=self.patch_size[2],
+        )
+
+    def _init_slots(
+        self,
+        batch_size: int,
+        *,
+        instance_state: Optional[torch.Tensor] = None,
+        state_override: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        """
+        instance_state:
+            - training: preferably pass (b, num_slots, state_dim) (already aligned/tracked per slot)
+            - alternatively pass (b, n_inst, state_dim) and do the instance-to-slot matching externally
+
+        state_override (for interactive inference):
+            {
+                "slot_ids": LongTensor (b,) or (b,k),
+                "state": Tensor (..., state_dim),
+                "alpha": float (default 1.0),
+                "hard": bool (default False)
+            }
+        """
+        slots = self.slot_base.expand(batch_size, -1, -1)  # (b, num_slots, dim)
+
+        if (instance_state is not None) and (self.instance_proj is not None):
+            # instance_state is expected to be (b, num_slots, state_dim)
+            slots = slots + self.instance_proj(instance_state)
+
+        # at inference: inject the target state into the selected slots to realize state control
+        if state_override is not None and self.instance_proj is not None:
+            slot_ids = state_override["slot_ids"]
+            target_state = state_override["state"]
+            alpha = float(state_override.get("alpha", 1.0))
+            hard = bool(state_override.get("hard", False))
+
+            # normalize shapes: slot_ids -> (b, k)
+            if slot_ids.dim() == 1:
+                slot_ids = slot_ids[:, None]  # (b,1)
+
+            b = slots.shape[0]
+            k = slot_ids.shape[1]
+
+            # target_state supports two layouts:
+            #   (b, state_dim)   -> broadcast to (b,k,state_dim)
+            #   (b,k,state_dim)  -> per-slot
+            if target_state.dim() == 2:
+                target_state = target_state[:, None, :].expand(b, k, -1)
+
+            delta = self.instance_proj(target_state)  # (b,k,dim)
+
+            if hard:
+                # hard: slot = base + proj(state)
+                base = self.slot_base.expand(b, -1, -1)
+                # write base into the selected slots first, then add delta
+                slots = slots.clone()
+                slots.scatter_(1, slot_ids[..., None].expand(b, k, slots.shape[-1]),
+                               gather_tokens(base, slot_ids) + delta)
+            else:
+                # soft: slot += alpha * proj(state)
+                slots = slots.clone()
+                cur = gather_tokens(slots, slot_ids)
+                new = cur + alpha * delta
+                slots = scatter_tokens(slots, slot_ids, new)
+
+        return slots
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: torch.Tensor,
+        context: torch.Tensor,
+        clip_feature: Optional[torch.Tensor] = None,
+        y: Optional[torch.Tensor] = None,
+        use_gradient_checkpointing: bool = False,
+        use_gradient_checkpointing_offload: bool = False,
+
+        # ---- new: instance state / interaction inputs ----
+        instance_state: Optional[torch.Tensor] = None,  # (b, num_slots, state_dim) (may be per-frame during training)
+        roi_idx: Optional[torch.Tensor] = None,  # (b, m) (at inference, built from a SAM mask via mask_to_roi_idx)
+        state_override: Optional[Dict[str, Any]] = None,  # interactive inference: force the state of selected slots
+
+        return_state_pred: bool = False,
+        **kwargs,
+    ):
+        # time + text
+        t =
self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + # optional image input + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # patchify -> tokens + x, (f, h, w) = self.patchify(x, control_camera_latents_input=kwargs.get("control_camera_latents_input", None)) + + # rope freqs for patch self-attn + freqs = torch.cat( + [ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(f * h * w, 1, -1).to(x.device) + + # init slots + slots = None + if self.enable_slots: + slots = self._init_slots( + batch_size=x.shape[0], + instance_state=instance_state, + state_override=state_override, + ) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # blocks + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + if self.enable_slots: + x, slots = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, slots, context, t_mod, freqs, roi_idx, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if self.enable_slots: + x, slots = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, slots, context, t_mod, freqs, roi_idx, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if self.enable_slots: + x, slots = block(x, slots, context, t_mod, freqs, roi_idx=roi_idx) + else: + x = block(x, context, t_mod, freqs) + + # output head + out = self.head(x, t) + out = self.unpatchify(out, (f, h, w)) + + if return_state_pred and self.enable_slots and (self.state_head is not None): + state_pred = self.state_head(slots) # (b, num_slots, state_head_dim) + return out, state_pred, slots + + return out + + +# --------- 原始 DiTBlock(保留以兼容 enable_slots=False)--------- +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approximate="tanh"), + nn.Linear(ffn_dim, dim), + ) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod + ).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + 
shift_msa.squeeze(2), + scale_msa.squeeze(2), + gate_msa.squeeze(2), + shift_mlp.squeeze(2), + scale_mlp.squeeze(2), + gate_mlp.squeeze(2), + ) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37d17d6a183f5d6290c3f4b1e417bf0310bfa353 --- /dev/null +++ b/diffsynth/models/wan_video_image_encoder.py @@ -0,0 +1,878 @@ +""" +Concise re-implementation of +``https://github.com/openai/CLIP'' and +``https://github.com/mlfoundations/open_clip''. +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from .wan_video_dit import flash_attention + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. 
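+ Returns the last hidden states of shape [B, L, C]; pooling is left to callers such as XLMRobertaWithHead.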
+ """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init model + if pretrained: + from sora import DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = XLMRoberta(**cfg) + + # load checkpoint + model.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'), + map_location=device), + assign=True) + else: + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + + # init tokenizer + if return_tokenizer: + from sora.data import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.text_len, + clean='whitespace') + return model, tokenizer + else: + return model + + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, 
dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + # compute query, key, value + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. 
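+ Returns a pooled embedding of shape [B, C]: a single learned cls query attends over x, followed by a residual MLP.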
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1) + k, v = self.to_kv(x).chunk(2, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + e = e.to(dtype=x.dtype, device=x.device) + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_mlp_ratio=4, + vision_heads=12, + vision_layers=12, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + vocab_size=49408, + text_len=77, + text_dim=512, + text_mlp_ratio=4, + text_heads=8, + text_layers=12, + text_causal=True, + text_pool='argmax', + text_head_bias=False, + logit_bias=None, + 
activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pool = vision_pool + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_mlp_ratio = text_mlp_ratio + self.text_heads = text_heads + self.text_layers = text_layers + self.text_causal = text_causal + self.text_pool = text_pool + self.text_head_bias = text_head_bias + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + mlp_ratio=text_mlp_ratio, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + causal=text_causal, + pool_type=text_pool, + head_bias=text_head_bias, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + if logit_bias is not None: + self.logit_bias = nn.Parameter(logit_bias * torch.ones([])) + + # initialize weights + self.init_weights() + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer. 
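+ Returns the raw vision-tower and text-tower outputs (xi, xt).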
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, std=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else self.text_dim + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * len(transformer))) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = None + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. 
+ - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=CLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init model + if pretrained and pretrained_name: + from sora import BUCKET, DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = model_cls(**kwargs) + + # checkpoint path + checkpoint = f'models/clip/{pretrained_name}' + if dtype in (torch.float16, torch.bfloat16): + suffix = '-' + { + torch.float16: 'fp16', + torch.bfloat16: 'bf16' + }[dtype] + if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'): + checkpoint = f'{checkpoint}{suffix}' + checkpoint += '.pth' + + # load + model.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device), + assign=True, + strict=False) + else: + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + + # init tokenizer + if return_tokenizer: + from sora import data + if 'siglip' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name=f'timm/{pretrained_name}', + seq_len=model.text_len, + clean='canonicalize') + elif 'xlm' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.max_text_len - 2, + clean='whitespace') + elif 'mba' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='facebook/xlm-roberta-xl', + seq_len=model.max_text_len - 2, + clean='whitespace') + else: + tokenizer = data.CLIPTokenizer( + seq_len=model.text_len, padding=tokenizer_padding) + output += (tokenizer,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class WanImageEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + 
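# build the ViT-H/14 vision tower plus its preprocessing transforms; no pretrained checkpoint is downloaded here +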
return_transforms=True, + return_tokenizer=False, + dtype=torch.float32, + device="cpu") + + def encode_image(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u, + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + out = self.model.visual(videos, use_31_block=True) + return out diff --git a/diffsynth/models/wan_video_mot.py b/diffsynth/models/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..4091c91777355dce91ccefac56679f8b936e7abb --- /dev/null +++ b/diffsynth/models/wan_video_mot.py @@ -0,0 +1,169 @@ +import torch +from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP +import einops +import torch.nn as nn + + +class MotSelfAttention(SelfAttention): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__(dim, num_heads, eps) + def forward(self, x, freqs, is_before_attn=False): + if is_before_attn: + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + return q, k, v + else: + return self.o(x) + + +class MotWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + + self.self_attn = MotSelfAttention(dim, num_heads, eps) + + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot): + + # 1. prepare scale parameter + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + + scale_params_mot_ref = self.modulation + t_mod_mot.float() + scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1) + shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2) + + # 2. 
Self-attention + input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa) + # original block self-attn + attn1 = wan_block.self_attn + q = attn1.norm_q(attn1.q(input_x)) + k = attn1.norm_k(attn1.k(input_x)) + v = attn1.v(input_x) + q = rope_apply(q, freqs, attn1.num_heads) + k = rope_apply(k, freqs, attn1.num_heads) + + # mot block self-attn + norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot) + norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1) + q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True) + + tmp_hidden_states = flash_attention( + torch.cat([q, q_mot], dim=-2), + torch.cat([k, k_mot], dim=-2), + torch.cat([v, v_mot], dim=-2), + num_heads=attn1.num_heads) + + attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2) + + attn_output = attn1.o(attn_output) + x = wan_block.gate(x, gate_msa, attn_output) + + attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False) + # gate + attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1) + attn_output_mot = attn_output_mot * gate_msa_mot_ref + attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot) + + # 3. cross-attention and feed-forward + x = x + wan_block.cross_attn(wan_block.norm3(x), context) + input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp) + x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x)) + + x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot) + # modulate + norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot) + norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1) + input_x_mot = self.ffn(norm_x_mot_ref) + # gate + input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1) + input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref + input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + input_x_mot).type_as(x_mot) + + return x, x_mot + + +class MotWanModel(torch.nn.Module): + def __init__( + self, + mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + patch_size=(1, 2, 2), + has_image_input=True, + has_image_pos_emb=False, + dim=5120, + num_heads=40, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + in_dim=36, + eps=1e-6, + ): + super().__init__() + self.mot_layers = mot_layers + self.freq_dim = freq_dim + self.dim = dim + + self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)} + self.head_dim = dim // num_heads + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) + + # mot blocks + self.blocks = torch.nn.ModuleList([ + MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.mot_layers + ]) + + + def patchify(self, x: torch.Tensor): + x = 
self.patch_embedding(x) + return x + + def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0): + def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta) + h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + + freqs = torch.cat([ + f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1), + h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1), + w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1) + return freqs + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id): + block = self.blocks[self.mot_layers_mapping[block_id]] + x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot) + return x, x_mot diff --git a/diffsynth/models/wan_video_motion_controller.py b/diffsynth/models/wan_video_motion_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..34763a8d76e57bc8efff84f23863938cc2309029 --- /dev/null +++ b/diffsynth/models/wan_video_motion_controller.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +from .wan_video_dit import sinusoidal_embedding_1d + + + +class WanMotionControllerModel(torch.nn.Module): + def __init__(self, freq_dim=256, dim=1536): + super().__init__() + self.freq_dim = freq_dim + self.linear = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6), + ) + + def forward(self, motion_bucket_id): + emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10) + emb = self.linear(emb) + return emb + + def init(self): + state_dict = self.linear[-1].state_dict() + state_dict = {i: state_dict[i] * 0 for i in state_dict} + self.linear[-1].load_state_dict(state_dict) diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..64090db8c65138abfdb60a822b3ba2e74fefeb4c --- /dev/null +++ b/diffsynth/models/wan_video_text_encoder.py @@ -0,0 +1,330 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoTokenizer +import ftfy +import html +import string +import regex as re + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, 
num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + 
torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class WanTextEncoder(torch.nn.Module): + + def __init__(self, + vocab=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1): + super(WanTextEncoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', 
text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text \ No newline at end of file diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..f3367f788891cb22b8a5bf6eaa50cabbc202ab4a --- /dev/null +++ b/diffsynth/models/wan_video_vace.py @@ -0,0 +1,87 @@ +import torch +from .wan_video_dit import DiTBlock + + +class VaceWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = torch.nn.Linear(self.dim, self.dim) + self.after_proj = torch.nn.Linear(self.dim, self.dim) + + def forward(self, c, x, context, t_mod, freqs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, context, t_mod, freqs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class VaceWanModel(torch.nn.Module): + def __init__( + self, + vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + vace_in_dim=96, + patch_size=(1, 2, 2), + has_image_input=False, + dim=1536, + num_heads=12, + ffn_dim=8960, + eps=1e-6, + ): + super().__init__() + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # vace blocks + self.vace_blocks = torch.nn.ModuleList([ + VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size) + + def forward( + self, x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ): + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for 
block in self.vace_blocks: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + c = block(c, x, context, t_mod, freqs) + hints = torch.unbind(c)[:-1] + return hints diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..d24e29d9398f95a59cbd1466542c6e7059f7c7af --- /dev/null +++ b/diffsynth/models/wan_video_vae.py @@ -0,0 +1,1382 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
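+ Interpolates in float32 and casts the result back to the input dtype.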
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + 
r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
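+ Attention runs per frame over the h*w spatial tokens; with the block-causal mask commented out below, it is effectively full attention within each frame.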
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else 
"downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, 
feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u 
for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + 
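+ # The head emits 12 channels rather than 3: with the patch_size=2 patchify used by
+ # VideoVAE38_, each output position packs a 2x2 spatial patch of RGB values, and
+ # VideoVAE38_.decode unpatchifies the result back to 3-channel frames.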
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = 
scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + self.z_dim = z_dim + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + 
is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.clamp_(-1, 1) + return values + + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x + + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.clamp_(-1, 1) + + + def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + 
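+ # Each latent is decoded independently. With tiled=True the latent is split into
+ # overlapping spatial tiles that are decoded on `device`, accumulated on CPU, and
+ # blended with the linear border masks from build_mask; otherwise single_decode runs
+ # the whole latent at once. For this 8x VAE a [16, T, h, w] latent decodes to roughly
+ # a [3, 4*T-3, 8*h, 8*w] video (matching the buffers allocated in tiled_decode).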
videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos + + + @staticmethod + def state_dict_converter(): + return WanVideoVAEStateDictConverter() + + +class WanVideoVAEStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + 
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..8807302d815a917123b794a597fc5fe84d3394fc --- /dev/null +++ b/diffsynth/models/wav2vec.py @@ -0,0 +1,191 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) + + +class WanS2VAudioEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + from transformers import Wav2Vec2ForCTC, Wav2Vec2Config + config = { + "_name_or_path": "facebook/wav2vec2-large-xlsr-53", + "activation_dropout": 0.05, + "apply_spec_augment": True, + "architectures": ["Wav2Vec2ForCTC"], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": True, + "conv_dim": [512, 512, 512, 512, 512, 512, 512], + "conv_kernel": [10, 3, 3, 3, 3, 2, 2], + "conv_stride": [5, 2, 2, 2, 2, 2, 2], + "ctc_loss_reduction": "mean", + "ctc_zero_infinity": True, + "do_stable_layer_norm": True, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.05, + "final_dropout": 0.0, + "hidden_act": "gelu", + "hidden_dropout": 0.05, + "hidden_size": 1024, + "initializer_range": 0.02, + 
"intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.05, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.7.0.dev0", + "vocab_size": 33 + } + self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config)) + self.video_rate = 30 + + def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'): + input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device) + + # retrieve logits & take argmax + res = self.model(input_values, output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + return feat + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0 + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 
0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'): + audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device) + audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype) + audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)] + return audio_embeds diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..7664fc5a37a2bf071677f09888cfa5e263ce7143 --- /dev/null +++ b/diffsynth/models/z_image_dit.py @@ -0,0 +1,621 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from torch.nn import RMSNorm +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + mid_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + mid_size, + out_size, + bias=True, + ), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(torch.bfloat16)) + return t_emb + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + 
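+ # Multi-head self-attention used by the Z-Image transformer blocks: separate Q/K/V
+ # projections, per-head RMSNorm on queries and keys, and rotary embeddings applied
+ # in forward(). to_out is a ModuleList so an optional dropout layer can follow the
+ # output projection.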
super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)]) + + self.norm_q = RMSNorm(head_dim, eps=1e-5) + self.norm_k = RMSNorm(head_dim, eps=1e-5) + + def forward(self, hidden_states, freqs_cis): + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = query.unflatten(-1, (self.num_heads, -1)) + key = key.unflatten(-1, (self.num_heads, -1)) + value = value.unflatten(-1, (self.num_heads, -1)) + + # Apply Norms + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # Compute joint attention + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = self.to_out[0](hidden_states) + if len(self.to_out) > 1: # dropout + output = self.to_out[1](output) + + return output + + +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + q_dim=dim, + num_heads=n_heads, + head_dim=dim // n_heads, + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + freqs_cis=freqs_cis, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2( + 
self.feed_forward( + self.ffn_norm1(x) * scale_mlp, + ) + ) + else: + # Attention block + attn_out = self.attention( + self.attention_norm1(x), + freqs_cis=freqs_cis, + ) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c): + scale = 1.0 + self.adaLN_modulation(c) + x = self.norm_final(x) * scale.unsqueeze(1) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageDiT(nn.Module): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock"] + + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = 
nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, 
pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + adaln_input = t + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device) + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + for layer in self.noise_refiner: + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, + attn_mask=x_attn_mask, + freqs_cis=x_freqs_cis, + adaln_input=adaln_input, + ) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device) + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, 
device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + for layer in self.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=cap_attn_mask, + freqs_cis=cap_freqs_cis, + ) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + for layer in self.layers: + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, + attn_mask=unified_attn_mask, + freqs_cis=unified_freqs_cis, + adaln_input=adaln_input, + ) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + return x, {} diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4eba636058b6299e41bec476ac348cbacc39cc37 --- /dev/null +++ b/diffsynth/models/z_image_text_encoder.py @@ -0,0 +1,41 @@ +from transformers import Qwen3Model, Qwen3Config +import torch + + +class ZImageTextEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + config = Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }) + self.model = Qwen3Model(config) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0046949b8766a087d4746f473406ac54f686bb --- /dev/null +++ b/diffsynth/pipelines/flux2_image.py @@ -0,0 +1,370 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple + +from ..diffusion import 
FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput + +from transformers import AutoProcessor +from ..models.flux2_text_encoder import Flux2TextEncoder +from ..models.flux2_dit import Flux2DiT +from ..models.flux2_vae import Flux2VAE + + +class Flux2ImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: Flux2TextEncoder = None + self.dit: Flux2DiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + Flux2Unit_ShapeChecker(), + Flux2Unit_PromptEmbedder(), + Flux2Unit_NoiseInitializer(), + Flux2Unit_InputImageEmbedder(), + Flux2Unit_ImageIDs(), + ] + self.model_fn = model_fn_flux2 + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder") + pipe.dit = model_pool.fetch_model("flux2_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + 
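+ # cfg_guided_model_fn evaluates the DiT for the positive inputs (and, presumably,
+ # the negative inputs when cfg_scale != 1) and returns the guided flow prediction;
+ # the scheduler step below then advances the latents by one denoising step.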
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + latents = rearrange(inputs_shared["latents"], "B (H W) C -> B C H W", H=inputs_shared["height"]//16, W=inputs_shared["width"]//16) + image = self.vae.decode(latents) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class Flux2Unit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class Flux2Unit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + + def format_text_input(self, prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + def get_mistral_3_small_prompt_embeds( + self, + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = self.format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_mistral_3_small_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + +class Flux2Unit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, 
rand_torch_dtype=pipe.torch_dtype) + noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1) + return {"noise": noise} + + +class Flux2Unit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: Flux2ImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + input_latents = rearrange(input_latents, "B C H W -> B (H W) C") + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class Flux2Unit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("image_ids",), + ) + + def prepare_latent_ids(self, height, width): + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1) + + return latent_ids + + def process(self, pipe: Flux2ImagePipeline, height, width): + image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device) + return {"image_ids": image_ids} + + +def model_fn_flux2( + dit: Flux2DiT, + latents=None, + timestep=None, + embedded_guidance=None, + prompt_embeds=None, + text_ids=None, + image_ids=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee5635ee16eac74ae4a3fd322d4aa4bd323893d --- /dev/null +++ b/diffsynth/pipelines/flux_image.py @@ -0,0 +1,1205 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.flux import FluxLoRALoader + +from ..models.flux_dit import FluxDiT +from ..models.flux_text_encoder_clip import FluxTextEncoderClip +from ..models.flux_text_encoder_t5 import FluxTextEncoderT5 +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.flux_value_control import MultiValueEncoder +from ..models.step1x_text_encoder import Step1xEditEmbedder +from ..core.vram.layers import AutoWrappedLinear + +class MultiControlNet(torch.nn.Module): + def __init__(self, models: 
list[torch.nn.Module]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + + def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs): + model = self.models[controlnet_input.controlnet_id] + res_stack, single_res_stack = model( + controlnet_conditioning=conditioning, + processor_id=controlnet_input.processor_id, + **kwargs + ) + res_stack = [res * controlnet_input.scale for res in res_stack] + single_res_stack = [res * controlnet_input.scale for res in single_res_stack] + return res_stack, single_res_stack + + def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs): + res_stack, single_res_stack = None, None + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start or progress < controlnet_input.end: + continue + res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs) + if res_stack is None: + res_stack = res_stack_ + single_res_stack = single_res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)] + return res_stack, single_res_stack + + +class FluxImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.1") + self.tokenizer_1: CLIPTokenizer = None + self.tokenizer_2: T5TokenizerFast = None + self.text_encoder_1: FluxTextEncoderClip = None + self.text_encoder_2: FluxTextEncoderT5 = None + self.dit: FluxDiT = None + self.vae_decoder: FluxVAEDecoder = None + self.vae_encoder: FluxVAEEncoder = None + self.controlnet = None + self.ipadapter = None + self.ipadapter_image_encoder = None + self.qwenvl = None + self.step1x_connector = None + self.nexus_gen = None + self.nexus_gen_generation_adapter = None + self.nexus_gen_editing_adapter = None + self.value_controller = None + self.infinityou_processor = None + self.image_proj_model = None + self.lora_patcher = None + self.lora_encoder = None + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") + self.units = [ + FluxImageUnit_ShapeChecker(), + FluxImageUnit_NoiseInitializer(), + FluxImageUnit_PromptEmbedder(), + FluxImageUnit_InputImageEmbedder(), + FluxImageUnit_ImageIDs(), + FluxImageUnit_EmbeddedGuidanceEmbedder(), + FluxImageUnit_Kontext(), + FluxImageUnit_InfiniteYou(), + FluxImageUnit_ControlNet(), + FluxImageUnit_IPAdapter(), + FluxImageUnit_EntityControl(), + FluxImageUnit_NexusGen(), + FluxImageUnit_TeaCache(), + FluxImageUnit_Flex(), + FluxImageUnit_Step1x(), + FluxImageUnit_ValueControl(), + FluxImageUnit_LoRAEncode(), + ] + self.model_fn = model_fn_flux_image + self.lora_loader = FluxLoRALoader + + def enable_lora_merger(self): + if not (hasattr(self.dit, "vram_management_enabled") and getattr(self.dit, "vram_management_enabled")): + raise ValueError("DiT VRAM management is not enabled.") + if self.lora_patcher is not None: + for name, module in self.dit.named_modules(): + if isinstance(module, AutoWrappedLinear): + merger_name = name.replace(".", "___") + if merger_name in 
self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), + tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"), + nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), + step1x_processor_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern=""), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder_1 = model_pool.fetch_model("flux_text_encoder_clip") + pipe.text_encoder_2 = model_pool.fetch_model("flux_text_encoder_t5") + pipe.dit = model_pool.fetch_model("flux_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + if tokenizer_1_config is not None: + tokenizer_1_config.download_if_necessary() + pipe.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_config.path) + if tokenizer_2_config is not None: + tokenizer_2_config.download_if_necessary() + pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path) + + value_controllers = model_pool.fetch_model("flux_value_controller") + if value_controllers is not None: + pipe.value_controller = MultiValueEncoder(value_controllers) + if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"): + pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled + controlnets = model_pool.fetch_model("flux_controlnet") + if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets) + pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") + pipe.ipadapter_image_encoder = model_pool.fetch_model("siglip_vision_model") + qwenvl = model_pool.fetch_model("qwen_image_text_encoder") + if qwenvl is not None: + from transformers import AutoProcessor + step1x_processor_config.download_if_necessary() + processor = AutoProcessor.from_pretrained(step1x_processor_config.path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28) + pipe.qwenvl = Step1xEditEmbedder(qwenvl, processor) + pipe.step1x_connector = model_pool.fetch_model("step1x_connector") + pipe.image_proj_model = model_pool.fetch_model("infiniteyou_image_projector") + if pipe.image_proj_model is not None: + pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_pool.fetch_model("flux_lora_patcher") + pipe.lora_encoder = model_pool.fetch_model("flux_lora_encoder") + pipe.nexus_gen = model_pool.fetch_model("nexus_gen_llm") + pipe.nexus_gen_generation_adapter = model_pool.fetch_model("nexus_gen_generation_adapter") + pipe.nexus_gen_editing_adapter = model_pool.fetch_model("nexus_gen_editing_adapter") + if pipe.nexus_gen is not None: + nexus_gen_processor_config.download_if_necessary() + pipe.nexus_gen.load_processor(nexus_gen_processor_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + 
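+        # Note: negative_prompt typically only affects the output when cfg_scale != 1.0;
+        # several units below skip the negative branch entirely at cfg_scale == 1.0.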
negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 3.5, + t5_sequence_length: int = 512, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Scheduler + sigma_shift: float = None, + # Steps + num_inference_steps: int = 30, + # local prompts + multidiffusion_prompts=(), + multidiffusion_masks=(), + multidiffusion_scales=(), + # Kontext + kontext_images: Union[list[Image.Image], Image.Image] = None, + # ControlNet + controlnet_inputs: list[ControlNetInput] = None, + # IP-Adapter + ipadapter_images: Union[list[Image.Image], Image.Image] = None, + ipadapter_scale: float = 1.0, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + eligen_enable_inpaint: bool = False, + # InfiniteYou + infinityou_id_image: Image.Image = None, + infinityou_guidance: float = 1.0, + # Flex + flex_inpaint_image: Image.Image = None, + flex_inpaint_mask: Image.Image = None, + flex_control_image: Image.Image = None, + flex_control_strength: float = 0.5, + flex_control_stop: float = 0.5, + # Value Controller + value_controller_inputs: Union[list[float], float] = None, + # Step1x + step1x_reference_image: Image.Image = None, + # NexusGen + nexus_gen_reference_image: Image.Image = None, + # LoRA Encoder + lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, + lora_encoder_scale: float = 1.0, + # TeaCache + tea_cache_l1_thresh: float = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, + "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, + "kontext_images": kontext_images, + "controlnet_inputs": controlnet_inputs, + "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, + "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, + "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "value_controller_inputs": value_controller_inputs, + "step1x_reference_image": step1x_reference_image, + "nexus_gen_reference_image": nexus_gen_reference_image, + "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, + "tea_cache_l1_thresh": tea_cache_l1_thresh, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "progress_bar_cmd": progress_bar_cmd, + } + for 
unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class FluxImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width"), output_params=("height", "width")) + + def process(self, pipe: FluxImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class FluxImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",)) + + def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) + return {"noise": noise} + + + +class FluxImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": None} + + + +class FluxImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + input_params=("t5_sequence_length",), + output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + 
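+            # The T5 prompt is padded/truncated to max_length, i.e. the pipeline's
+            # t5_sequence_length argument (512 by default).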
return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device="cuda", + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict: + if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length, + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} + + +class FluxImageUnit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents",), output_params=("image_ids",)) + + def process(self, pipe: FluxImagePipeline, latents): + latent_image_ids = pipe.dit.prepare_image_ids(latents) + return {"image_ids": latent_image_ids} + + + +class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",)) + + def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): + guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + return {"guidance": guidance} + + + +class FluxImageUnit_Kontext(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("kontext_images", "tiled", "tile_size", "tile_stride"), + output_params=("kontext_latents", "kontext_image_ids"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): + if kontext_images is None: + return {} + if not isinstance(kontext_images, list): + kontext_images = [kontext_images] + + kontext_latents = [] + kontext_image_ids = [] + for kontext_image in kontext_images: + kontext_image = pipe.preprocess_image(kontext_image) + kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image_ids = pipe.dit.prepare_image_ids(kontext_latent) + image_ids[..., 0] = 1 + kontext_image_ids.append(image_ids) + kontext_latent = pipe.dit.patchify(kontext_latent) + kontext_latents.append(kontext_latent) + kontext_latents = torch.concat(kontext_latents, dim=1) + kontext_image_ids = torch.concat(kontext_image_ids, dim=-2) + return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids} + + + +class FluxImageUnit_ControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("controlnet_conditionings",), + onload_model_names=("vae_encoder",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = 
mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if controlnet_inputs is None: + return {} + pipe.load_models_to_device(['vae_encoder']) + conditionings = [] + for controlnet_input in controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + return {"controlnet_conditionings": conditionings} + + + +class FluxImageUnit_IPAdapter(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("ipadapter_images", "ipadapter_scale"), + output_params=("ipadapter_kwargs_list",), + onload_model_names=("ipadapter_image_encoder", "ipadapter") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) + if ipadapter_images is None: + return inputs_shared, inputs_posi, inputs_nega + if not isinstance(ipadapter_images, list): + ipadapter_images = [ipadapter_images] + + pipe.load_models_to_device(self.onload_model_names) + images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] + images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images] + ipadapter_images = torch.cat(images, dim=0) + ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output + + inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))}) + return inputs_shared, inputs_posi, inputs_nega + + + +class FluxImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + 
truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device="cuda", + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + + prompt_emb, _, _ = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length, + ) + return prompt_emb.unsqueeze(0), entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_NexusGen(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("nexus_gen_reference_image", "prompt", "latents"), + 
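+            # When a NexusGen LLM is loaded, this unit overrides the prompt embedding and text ids
+            # from the default prompt embedder, using the generation adapter for text-to-image or
+            # the editing adapter when a reference image is supplied.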
output_params=("prompt_emb", "text_ids"), + onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.nexus_gen is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + if inputs_shared.get("nexus_gen_reference_image", None) is None: + assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set." + embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0) + inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed) + inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype) + else: + assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set." + embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"]) + embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long) + ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long) + + inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid) + inputs_posi["text_ids"] = self.get_editing_text_ids( + inputs_shared["latents"], + embeds_grid[0][1].item(), embeds_grid[0][2].item(), + ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(), + ) + return inputs_shared, inputs_posi, inputs_nega + + + def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width): + # prepare text ids for target and reference embeddings + batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width + embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype) + + batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width + ref_embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0 + ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype) + + text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1) + return text_ids + + +class FluxImageUnit_Step1x(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("step1x_reference_image", "prompt", "negative_prompt"), + output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"), + onload_model_names=("qwenvl","vae_encoder") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, 
inputs_nega: dict): + image = inputs_shared.get("step1x_reference_image",None) + if image is None: + return inputs_shared, inputs_posi, inputs_nega + else: + pipe.load_models_to_device(self.onload_model_names) + prompt = inputs_posi["prompt"] + nega_prompt = inputs_nega["negative_prompt"] + captions = [prompt, nega_prompt] + ref_images = [image, image] + embs, masks = pipe.qwenvl(captions, ref_images) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image) + inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}) + if inputs_shared.get("cfg_scale", 1) != 1: + inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",)) + + def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): + if tea_cache_l1_thresh is None: + return {} + else: + return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)} + +class FluxImageUnit_Flex(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), + output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): + if pipe.dit.input_dim == 196: + if flex_control_stop is None: + flex_control_stop = 1 + pipe.load_models_to_device(self.onload_model_names) + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))] + return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, 
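+                # flex_control_stop_timestep: the control branch is fed to the DiT while the current
+                # timestep is >= this value and replaced by flex_uncondition afterwards.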
"flex_control_stop_timestep": flex_control_stop_timestep} + else: + return {} + + + +class FluxImageUnit_InfiniteYou(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("infinityou_id_image", "infinityou_guidance"), + output_params=("id_emb", "infinityou_guidance"), + onload_model_names=("infinityou_processor",) + ) + + def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance): + pipe.load_models_to_device("infinityou_processor") + if infinityou_id_image is not None: + return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device) + else: + return {} + + + +class FluxImageUnit_ValueControl(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + input_params=("value_controller_inputs",), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("value_controller",) + ) + + def add_to_text_embedding(self, prompt_emb, text_ids, value_emb): + prompt_emb = torch.concat([prompt_emb, value_emb], dim=1) + extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs): + if value_controller_inputs is None: + return {} + if not isinstance(value_controller_inputs, list): + value_controller_inputs = [value_controller_inputs] + value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device) + pipe.load_models_to_device(["value_controller"]) + value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype) + value_emb = value_emb.unsqueeze(0) + prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb) + return {"prompt_emb": prompt_emb, "text_ids": text_ids} + + + +class InfinitYou(torch.nn.Module): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__() + from facexlib.recognition import init_recognition_model + from insightface.app import FaceAnalysis + self.device = device + self.torch_dtype = torch_dtype + insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface' + self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_640.prepare(ctx_id=0, det_size=(640, 640)) + self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_320.prepare(ctx_id=0, det_size=(320, 320)) + self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_160.prepare(ctx_id=0, det_size=(160, 160)) + self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype) + + def _detect_face(self, id_image_cv2): + face_info = self.app_640.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_320.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_160.get(id_image_cv2) + return face_info + + def extract_arcface_bgr_embedding(self, in_image, landmark, device): + from insightface.utils import face_align + arc_face_image = 
face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112) + arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255. + arc_face_image = 2 * arc_face_image - 1 + arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype) + face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized + return face_emb + + def prepare_infinite_you(self, model, id_image, infinityou_guidance, device): + import cv2 + if id_image is None: + return {'id_emb': None} + id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR) + face_info = self._detect_face(id_image_cv2) + if len(face_info) == 0: + raise ValueError('No face detected in the input ID image') + landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face + id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device) + id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype)) + infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype) + return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance} + + + +class FluxImageUnit_LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("lora_encoder_inputs", "lora_encoder_scale"), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("lora_encoder",) + ) + + def parse_lora_encoder_inputs(self, lora_encoder_inputs): + if not isinstance(lora_encoder_inputs, list): + lora_encoder_inputs = [lora_encoder_inputs] + lora_configs = [] + for lora_encoder_input in lora_encoder_inputs: + if isinstance(lora_encoder_input, str): + lora_encoder_input = ModelConfig(path=lora_encoder_input) + lora_encoder_input.download_if_necessary() + lora_configs.append(lora_encoder_input) + return lora_configs + + def load_lora(self, lora_config, dtype, device): + loader = FluxLoRALoader(torch_dtype=dtype, device=device) + lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device) + lora = loader.convert_state_dict(lora) + return lora + + def lora_embedding(self, pipe, lora_encoder_inputs): + lora_emb = [] + for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs): + lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device) + lora_emb.append(pipe.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb): + prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1) + extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("lora_encoder_inputs", None) is None: + return inputs_shared, inputs_posi, inputs_nega + + # Encode + pipe.load_models_to_device(["lora_encoder"]) + lora_encoder_inputs = inputs_shared["lora_encoder_inputs"] + lora_emb = self.lora_embedding(pipe, lora_encoder_inputs) + + # Scale + lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None) + if lora_encoder_scale is not None: + lora_emb = lora_emb * lora_encoder_scale + + # Add to prompt embedding + inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding( + inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb) + return 
inputs_shared, inputs_posi, inputs_nega + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: FluxDiT, hidden_states, conditioning): + inp = hidden_states.clone() + temb_ = conditioning.clone() + modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +class FastTileWorker: + def __init__(self): + pass + + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + B, C, H, W = model_input.shape + border_width = int(tile_stride*0.5) if border_width is None else border_width + weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device) + values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + # Forward + hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + +def model_fn_flux_image( + dit: FluxDiT, + controlnet=None, + step1x_connector=None, + latents=None, + 
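+    # latents arrive in spatial (B, C, H, W) form; they are patchified into a token
+    # sequence inside this function and unpatchified back to (B, C, H, W) before returning.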
timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + kontext_latents=None, + kontext_image_ids=None, + controlnet_inputs=None, + controlnet_conditionings=None, + tiled=False, + tile_size=128, + tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, + ipadapter_kwargs_list={}, + id_emb=None, + infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, + tea_cache: TeaCache = None, + progress_id=0, + num_inference_steps=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs +): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None + return model_fn_flux_image( + dit=dit, + controlnet=controlnet, + latents=latents[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_inputs=controlnet_inputs, + controlnet_conditionings=tiled_controlnet_conditionings, + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + latents, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=latents.device, + tile_dtype=latents.dtype + ) + + hidden_states = latents + + # ControlNet + if controlnet is not None and controlnet_conditionings is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "controlnet_inputs": controlnet_inputs, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + "progress_id": progress_id, + "num_inference_steps": num_inference_steps, + } + if id_emb is not None: + controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype) + controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance}) + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_conditionings, **controlnet_extra_kwargs + ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + + # Kontext + if kontext_latents is not None: + image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, kontext_latents], 
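+        # Kontext reference tokens are appended after the image tokens here and sliced off
+        # again after the transformer blocks (see the trim near the end of this function).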
dim=1) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + + hidden_states = dit.x_embedder(hidden_states) + + # EliGen + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1]) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None: + if kontext_latents is None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + else: + hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None: + if kontext_latents is None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + else: + hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + + # Kontext + if kontext_latents is not None: + hidden_states = hidden_states[:, :-kontext_latents.shape[1]] + + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..fd59a8e81ac63d0980e9a585590056b88ed4096a --- /dev/null +++ b/diffsynth/pipelines/qwen_image.py @@ -0,0 +1,746 @@ +import 
torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora + +from ..models.qwen_image_dit import QwenImageDiT +from ..models.qwen_image_text_encoder import QwenImageTextEncoder +from ..models.qwen_image_vae import QwenImageVAE +from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel + + +class QwenImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + from transformers import Qwen2Tokenizer, Qwen2VLProcessor + + self.scheduler = FlowMatchScheduler("Qwen-Image") + self.text_encoder: QwenImageTextEncoder = None + self.dit: QwenImageDiT = None + self.vae: QwenImageVAE = None + self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None + self.tokenizer: Qwen2Tokenizer = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: QwenImageImage2LoRAModel = None + self.image2lora_coarse: QwenImageImage2LoRAModel = None + self.image2lora_fine: QwenImageImage2LoRAModel = None + self.processor: Qwen2VLProcessor = None + self.in_iteration_models = ("dit", "blockwise_controlnet") + self.units = [ + QwenImageUnit_ShapeChecker(), + QwenImageUnit_NoiseInitializer(), + QwenImageUnit_InputImageEmbedder(), + QwenImageUnit_Inpaint(), + QwenImageUnit_EditImageEmbedder(), + QwenImageUnit_ContextImageEmbedder(), + QwenImageUnit_PromptEmbedder(), + QwenImageUnit_EntityControl(), + QwenImageUnit_BlockwiseControlNet(), + ] + self.model_fn = model_fn_qwen_image + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + processor_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder") + pipe.dit = model_pool.fetch_model("qwen_image_dit") + pipe.vae = model_pool.fetch_model("qwen_image_vae") + pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all")) + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + from transformers import Qwen2Tokenizer + pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path) + if processor_config is not None: + processor_config.download_if_necessary() + from transformers import Qwen2VLProcessor + pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path) + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + 
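+        # Optional Image2LoRA adapters: model_pool.fetch_model returns None when the corresponding
+        # checkpoint is not listed in model_configs, so these remain unset by default.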
pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style") + pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse") + pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine") + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Inpaint + inpaint_mask: Image.Image = None, + inpaint_blur_size: int = None, + inpaint_blur_sigma: float = None, + # Shape + height: int = 1328, + width: int = 1328, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + exponential_shift_mu: float = None, + # Blockwise ControlNet + blockwise_controlnet_inputs: list[ControlNetInput] = None, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + # Qwen-Image-Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, + edit_rope_interpolation: bool = False, + # In-context control + context_image: Image.Image = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + "blockwise_controlnet_inputs": blockwise_controlnet_inputs, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, + "context_image": context_image, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + 
return image + + +class QwenImageBlockwiseMultiControlNet(torch.nn.Module): + def __init__(self, models: list[QwenImageBlockWiseControlNet]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + for model in models: + if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"): + self.vram_management_enabled = True + + def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs): + processed_conditionings = [] + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning) + processed_conditionings.append(model_output) + return processed_conditionings + + def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs): + res = 0 + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4): + continue + model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id) + res = res + model_output * controlnet_input.scale + return res + + +class QwenImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: QwenImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class QwenImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + + +class QwenImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + + +class QwenImageUnit_Inpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"), + output_params=("inpaint_mask",), + ) + + def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, 
inpaint_blur_size, inpaint_blur_sigma): + if inpaint_mask is None: + return {} + inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1) + inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True) + if inpaint_blur_size is not None and inpaint_blur_sigma is not None: + from torchvision.transforms import GaussianBlur + blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma) + inpaint_mask = blur(inpaint_mask) + return {"inpaint_mask": inpaint_mask} + + +class QwenImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def calculate_dimensions(self, target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + def resize_image(self, image, target_area=384*384): + width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1]) + return image.resize((width, height)) + + def encode_prompt(self, pipe: QwenImagePipeline, prompt): + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + if model_inputs.input_ids.shape[1] >= 1024: + print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))]) + txt = [template.format(base_img_prompt + e) for e in prompt] + edit_image = [self.resize_image(image) for image in edit_image] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: + pipe.load_models_to_device(self.onload_model_names) + if pipe.text_encoder is not None: + prompt = [prompt] + if edit_image is None: + split_hidden_states = self.encode_prompt(pipe, prompt) + elif isinstance(edit_image, Image.Image): + split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image) + else: + split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks", 
"entity_prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict: + if pipe.text_encoder is not None: + prompt = [prompt] + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] + + split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + prompt_embs, prompt_emb_masks = [], [] + for entity_prompt in entity_prompts: + prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt) + prompt_embs.append(prompt_emb_dict['prompt_emb']) + prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask']) + return prompt_embs, prompt_emb_masks, entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi) + entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": 
entity_prompt_emb_posi_mask} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + + +class QwenImageUnit_BlockwiseControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("blockwise_controlnet_conditioning",), + onload_model_names=("vae",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if blockwise_controlnet_inputs is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + conditionings = [] + for controlnet_input in blockwise_controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + + return {"blockwise_controlnet_conditioning": conditionings} + + +class QwenImageUnit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"), + output_params=("edit_latents", "edit_image"), + onload_model_names=("vae",) + ) + + + def calculate_dimensions(self, target_area, ratio): + import math + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + + def edit_image_auto_resize(self, edit_image): + calculated_width, calculated_height = 
self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) + return edit_image.resize((calculated_width, calculated_height)) + + + def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(edit_image, Image.Image): + resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image + edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype) + edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + else: + resized_edit_image, edit_latents = [], [] + for image in edit_image: + if edit_image_auto_resize: + image = self.edit_image_auto_resize(image) + resized_edit_image.append(image) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + edit_latents.append(latents) + return {"edit_latents": edit_latents, "edit_image": resized_edit_image} + + +class QwenImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8) + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + prompt = [prompt] + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return prompt_embeds.view(1, -1) + + def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False): + pipe.load_models_to_device(["text_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) if highres else self.processor_lowres(image) + embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + residual = None + residual_highres = None + if pipe.image2lora_coarse is not None: + residual = self.encode_images_using_qwenvl(pipe, images, highres=False) + if pipe.image2lora_fine is not None: + residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True) + return x, residual, residual_highres + + def process(self, pipe: QwenImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x, residual, residual_highres = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres} + + +class QwenImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + output_params=("lora",), + onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"), + ) + + def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres): + 
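+ # Build one LoRA per reference image for each available Image2LoRA head: the
+ # style head takes no residual, while the coarse/fine heads consume the
+ # low-/high-resolution Qwen-VL residuals. The per-image LoRAs are then
+ # combined with merge_lora using alpha = 1 / number of reference images.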
if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + if pipe.image2lora_coarse is not None: + pipe.load_models_to_device(["image2lora_coarse"]) + for x, residual in zip(image2lora_x, image2lora_residual): + loras.append(pipe.image2lora_coarse(x=x, residual=residual)) + if pipe.image2lora_fine is not None: + pipe.load_models_to_device(["image2lora_fine"]) + for x, residual in zip(image2lora_x, image2lora_residual_highres): + loras.append(pipe.image2lora_fine(x=x, residual=residual)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +class QwenImageUnit_ContextImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("context_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride): + if context_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype) + context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"context_latents": context_latents} + + +def model_fn_qwen_image( + dit: QwenImageDiT = None, + blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + blockwise_controlnet_conditioning=None, + blockwise_controlnet_inputs=None, + progress_id=0, + num_inference_steps=1, + entity_prompt_emb=None, + entity_prompt_emb_mask=None, + entity_masks=None, + edit_latents=None, + context_latents=None, + enable_fp8_attention=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + edit_rope_interpolation=False, + **kwargs +): + img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + timestep = timestep / 1000 + + image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image_seq_len = image.shape[1] + + if context_latents is not None: + img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)] + context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2) + image = torch.cat([image, context_image], dim=1) + if edit_latents is not None: + edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents] + img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list] + edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list] + image = torch.cat([image] + edit_image, dim=1) + + image = dit.img_in(image) + conditioning = dit.time_text_embed(timestep, image.dtype) + + if entity_prompt_emb is not None: + text, image_rotary_emb, attention_mask = dit.process_entity_masks( + latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, + entity_masks, height, width, image, img_shapes, + ) + else: + text = dit.txt_in(dit.txt_norm(prompt_emb)) + if 
edit_rope_interpolation: + image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device) + else: + image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + attention_mask = None + + if blockwise_controlnet_conditioning is not None: + blockwise_controlnet_conditioning = blockwise_controlnet.preprocess( + blockwise_controlnet_inputs, blockwise_controlnet_conditioning) + + for block_id, block in enumerate(dit.transformer_blocks): + text, image = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + ) + if blockwise_controlnet_conditioning is not None: + image_slice = image[:, :image_seq_len].clone() + controlnet_output = blockwise_controlnet.blockwise_forward( + image=image_slice, conditionings=blockwise_controlnet_conditioning, + controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id, + progress_id=progress_id, num_inference_steps=num_inference_steps, + ) + image[:, :image_seq_len] = image_slice + controlnet_output + + image = dit.norm_out(image, conditioning) + image = dit.proj_out(image) + image = image[:, :image_seq_len] + + latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) + return latents diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..fa43db1dd893ac5c9271c80eb9589c6d16e9b66e --- /dev/null +++ b/diffsynth/pipelines/wan_video.py @@ -0,0 +1,1517 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import Wav2Vec2Processor + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel +from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel +from ..models.wav2vec import WanS2VAudioEncoder +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.audio_processor: Wav2Vec2Processor = None + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = 
None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp() + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_pool.fetch_model("wan_video_vae") + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: 
Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] 
is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + 
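+ # Encode the preprocessed video into VAE latents; if a VACE reference image
+ # (or a list of them) is given, its latents are encoded separately and
+ # prepended along the temporal dimension ahead of the video latents.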
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = 
torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + 
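+ # Resize the reference image to the target resolution and encode it with the
+ # VAE; when a CLIP image encoder is available, also return its CLIP feature.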
reference_image = reference_image.resize((width, height)) + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return 
{} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + 
output_params=("vap_clip_feature", "vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. 
prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return 
inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None and len(motion_video) > 0: + assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" + motion_latents = pipe.preprocess_video(motion_video) + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + pipe.load_models_to_device(["vae"]) + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), 
inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." + shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + + +class WanVideoUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), + output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") + ) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, 
"animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} + + +class WanVideoUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + output_params=("pose_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("animate_face_video",), + output_params=("face_pixel_values"), + ) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return inputs_shared, inputs_posi, inputs_nega + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, 
lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: 
torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # wan2.2 s2v + if audio_embeds is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = 
dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + x = dit.patchify(x, control_camera_latents_input) + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + + # Patchify + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + # Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + 
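+            # Each DiT block below attends over the packed video tokens. When a VAP adapter
+            # (`vap`) is attached and this block index appears in `vap.mot_layers_mapping`,
+            # the block and the VAP branch are evaluated jointly; otherwise the block runs alone.
+            # Both paths may be wrapped in gradient checkpointing (optionally offloaded to CPU).
+            # After the block, VACE hints are added for mapped layers, and the Animate adapter,
+            # if active, post-processes the hidden states.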
# Block + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return output + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_embeds, + motion_latents, + s2v_pose_latents, + drop_motion_frames=True, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, +): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + 
get_sequence_parallel_world_size, + get_sp_group) + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) + + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel + + # reference image + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) + + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] + x = 
dit.head(x, t[:-1])
+    x = dit.unpatchify(x, (f, h, w))
+    # make compatible with wan video
+    x = torch.cat([origin_ref_latents, x], dim=2)
+    return x
diff --git a/diffsynth/pipelines/wan_video_comp_attn.py b/diffsynth/pipelines/wan_video_comp_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea691d5f98df50d944677415f806b46fdd66a12a
--- /dev/null
+++ b/diffsynth/pipelines/wan_video_comp_attn.py
@@ -0,0 +1,108 @@
+from typing import Optional
+
+from .wan_video import WanVideoPipeline, ModelConfig
+from ..models.comp_attn_model import (
+    CompAttnConfig,
+    CompAttnMergeUnit,
+    CompAttnUnit,
+    patch_cross_attention,
+    wrap_model_fn,
+)
+from .wan_video import WanVideoUnit_PromptEmbedder, WanVideoUnit_CfgMerger
+
+
+def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline:
+    if getattr(pipe, "_comp_attn_attached", False):
+        return pipe
+    prompt_idx = None
+    cfg_idx = None
+    for idx, unit in enumerate(pipe.units):
+        if prompt_idx is None and isinstance(unit, WanVideoUnit_PromptEmbedder):
+            prompt_idx = idx
+        if cfg_idx is None and isinstance(unit, WanVideoUnit_CfgMerger):
+            cfg_idx = idx
+    if prompt_idx is not None:
+        pipe.units.insert(prompt_idx + 1, CompAttnUnit())
+    else:
+        pipe.units.append(CompAttnUnit())
+    if cfg_idx is not None:
+        pipe.units.insert(cfg_idx + 1, CompAttnMergeUnit())
+    else:
+        pipe.units.append(CompAttnMergeUnit())
+    patch_cross_attention(pipe)
+    wrap_model_fn(pipe)
+    pipe._comp_attn_attached = True
+    return pipe
+
+
+class WanVideoCompAttnPipeline:
+    """Comp-Attn enhanced video generation pipeline.
+
+    Two annotation modes are supported:
+
+    1. Explicit marker mode (recommended):
+       Tag subjects in the prompt with index markers; each index corresponds to an entry in `bboxes`.
+
+       ```python
+       prompt = "A <0>red car drives left, a <1>blue bicycle rides right"
+       comp_attn = CompAttnConfig(
+           bboxes=[car_bboxes, bike_bboxes],  # matched to marker indices <0>, <1>
+       )
+       ```
+
+    2. Implicit search mode (backward compatible):
+       Provide a `subjects` list; each subject is matched automatically by searching the prompt.
+
+       ```python
+       prompt = "A red car drives left, a blue bicycle rides right"
+       comp_attn = CompAttnConfig(
+           subjects=["red car", "blue bicycle"],
+           bboxes=[car_bboxes, bike_bboxes],
+       )
+       ```
+    """
+
+    def __init__(self, pipe: WanVideoPipeline):
+        self.pipe = attach_comp_attn(pipe)
+
+    def __getattr__(self, name):
+        return getattr(self.pipe, name)
+
+    @staticmethod
+    def from_pretrained(
+        torch_dtype=None,
+        device="cuda",
+        model_configs: list[ModelConfig] = None,
+        tokenizer_config: Optional[ModelConfig] = None,
+        audio_processor_config: Optional[ModelConfig] = None,
+        redirect_common_files: bool = True,
+        use_usp: bool = False,
+        vram_limit: float = None,
+    ):
+        pipe = WanVideoPipeline.from_pretrained(
+            torch_dtype=torch_dtype,
+            device=device,
+            model_configs=model_configs or [],
+            tokenizer_config=tokenizer_config,
+            audio_processor_config=audio_processor_config,
+            redirect_common_files=redirect_common_files,
+            use_usp=use_usp,
+            vram_limit=vram_limit,
+        )
+        return WanVideoCompAttnPipeline(pipe)
+
+    def __call__(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        comp_attn: Optional[CompAttnConfig] = None,
+        **kwargs,
+    ):
+        num_frames = kwargs.get("num_frames")
+        if num_frames is not None:
+            self.pipe._comp_attn_num_frames = num_frames
+
+        self.pipe._comp_attn_config = comp_attn
+        self.pipe._comp_attn_last_prompt = prompt
+        self.pipe._comp_attn_last_negative_prompt = negative_prompt
+        return self.pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs)
diff --git a/diffsynth/pipelines/wan_video_instanceV.py b/diffsynth/pipelines/wan_video_instanceV.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dbd42ba69c6c2fcf8a7398743a2c77e97578329
--- /dev/null
+++ b/diffsynth/pipelines/wan_video_instanceV.py
@@ -0,0 +1,1825 @@
+import torch, types
+import numpy as np
+from PIL import Image
+from einops import repeat
+from typing import Optional, Union
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+from typing_extensions import Literal
+from transformers import Wav2Vec2Processor
+
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig, gradient_checkpoint_forward
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+
+from ..models.wan_video_dit_instancev import WanModel, sinusoidal_embedding_1d, apply_saug
+from ..models.wan_video_dit_s2v import rope_precompute
+from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
+from ..models.wan_video_mot import MotWanModel
+from ..models.wav2vec import WanS2VAudioEncoder
+from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
+import math
+
+class WanVideoPipeline(BasePipeline):
+
+    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+        super().__init__(
+            device=device, torch_dtype=torch_dtype,
+            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
+        )
+        self.scheduler = FlowMatchScheduler("Wan")
+        self.tokenizer: HuggingfaceTokenizer = None
+        self.audio_processor: Wav2Vec2Processor = None
+        self.text_encoder: WanTextEncoder = None
+        self.image_encoder:
WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_InstanceV(), # <==== InstanceV: 需要在 latents 创建后运行 + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp() + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_pool.fetch_model("wan_video_vae") + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: 
Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + + # ===== InstanceV (new) ===== + instance_prompts: Optional[list[str]] = None, + # 二选一:给 masks(更精确)或给 bboxes(更方便) + instance_masks: Optional[list] = None, # 见下:支持 [Nins] 或 [Nins][num_frames] 的 PIL mask + instance_bboxes: Optional[list] = None, # 见下:支持 [num_frames][Nins] 的 (x0,y0,x1,y1) + empty_instance_token: str = "", + + # SAUG (可选) + saug_scale: float = 0.0, # 论文的 w + saug_drop_prob: float = 0.0, # 训练里用的 dropout,推理一般 0 + + + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + # ===== InstanceV (new) ===== + "instance_prompts": 
instance_prompts, + "instance_masks": instance_masks, + "instance_bboxes": instance_bboxes, + "empty_instance_token": empty_instance_token, + "saug_scale": saug_scale, + "saug_drop_prob": saug_drop_prob, + + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + + +class WanVideoUnit_InstanceV(PipelineUnit): + """ + 产出: + - instance_prompt_tokens: (B=1, F_lat, Nins, D_text) + - empty_instance_prompt_tokens: (B=1, F_lat, Nins, D_text) + - instance_attn_mask: (B=1, F_tok, Nins, HW_tok) bool + """ + def __init__(self): + super().__init__( + input_params=( + "instance_prompts", "instance_masks", "instance_bboxes", + "empty_instance_token", "height", "width", "num_frames", + "latents", "saug_scale", "saug_drop_prob", + ), + output_params=( + "instance_prompt_tokens", "empty_instance_prompt_tokens", + "instance_attn_mask", "saug_scale", "saug_drop_prob", + ), + ) + + @staticmethod + def _pool_text_emb(token_emb: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + # token_emb: (1, L, D), attn_mask: (1, L) + m = attn_mask.to(token_emb.dtype).unsqueeze(-1) # (1,L,1) + denom = m.sum(dim=1).clamp(min=1.0) + return (token_emb * m).sum(dim=1) / denom # (1,D) + + def _encode_one(self, pipe, text: str) -> torch.Tensor: + # 返回 (1, D_text) + ids, mask = pipe.tokenizer(text, return_mask=True, return_token_type_ids=False) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + 
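+        # Encode the instance prompt with the pipeline's text encoder (shape (1, L, D)),
+        # then mean-pool the token embeddings over the attention mask via _pool_text_emb,
+        # so each instance prompt is reduced to a single (1, D_text) vector.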
out = pipe.text_encoder(ids, mask).to(torch.float32)  # (1, L, D)
+        return self._pool_text_emb(out, mask)  # (1, D)
+
+    def _downsample_time(self, items, target_len: int):
+        # items: a list of length num_frames or of latent length
+        if items is None:
+            return None
+        if len(items) == target_len:
+            return items
+        # If the raw num_frames (e.g. 81) is given, subsample with the usual VAE time stride of 4, i.e. indices [0, 4, 8, ...]
+        idx = [min(i*4, len(items)-1) for i in range(target_len)]
+        return [items[i] for i in idx]
+
+    def _build_mask_from_bboxes(self, height, width, f_lat, bboxes, pipe):
+        # bboxes: [f_lat][Nins] of (x0, y0, x1, y1) in pixel coordinates
+        ps_t, ps_h, ps_w = pipe.dit.patch_size
+        f_tok = f_lat // ps_t
+        # Spatial token grid: VAE downsamples by 8 first, then patchify by /ps_h and /ps_w
+        h_lat = height // pipe.vae.upsampling_factor
+        w_lat = width // pipe.vae.upsampling_factor
+        h_tok = h_lat // ps_h
+        w_tok = w_lat // ps_w
+        hw_tok = h_tok * w_tok
+
+        bboxes_tok = self._downsample_time(bboxes, f_tok)
+
+        nins = len(bboxes_tok[0]) if (bboxes_tok and bboxes_tok[0]) else 0
+        mask = torch.zeros((1, f_tok, nins, hw_tok), dtype=torch.bool, device=pipe.device)
+
+        for t in range(f_tok):
+            for j in range(nins):
+                box = bboxes_tok[t][j]
+                if box is None:
+                    continue
+                x0, y0, x1, y1 = box
+                # Map to token coordinates
+                tx0 = int(math.floor(x0 * w_tok / width))
+                tx1 = int(math.ceil(x1 * w_tok / width))
+                ty0 = int(math.floor(y0 * h_tok / height))
+                ty1 = int(math.ceil(y1 * h_tok / height))
+                tx0, tx1 = max(0, tx0), min(w_tok, tx1)
+                ty0, ty1 = max(0, ty0), min(h_tok, ty1)
+                if tx1 <= tx0 or ty1 <= ty0:
+                    continue
+                grid = torch.zeros((h_tok, w_tok), dtype=torch.bool, device=pipe.device)
+                grid[ty0:ty1, tx0:tx1] = True
+                mask[0, t, j] = grid.flatten()
+
+        return mask
+
+    def _build_mask_from_masks(self, height, width, f_lat, masks, pipe):
+        # masks: [Nins] or [Nins][num_frames or latent length]
+        ps_t, ps_h, ps_w = pipe.dit.patch_size
+        f_tok = f_lat // ps_t
+        h_lat = height // pipe.vae.upsampling_factor
+        w_lat = width // pipe.vae.upsampling_factor
+        h_tok = h_lat // ps_h
+        w_tok = w_lat // ps_w
+        hw_tok = h_tok * w_tok
+
+        # Normalize to [Nins][f_tok]
+        if len(masks) > 0 and isinstance(masks[0], Image.Image):
+            masks = [[m]*f_tok for m in masks]
+        else:
+            masks = [self._downsample_time(m_list, f_tok) for m_list in masks]
+
+        nins = len(masks)
+        out = torch.zeros((1, f_tok, nins, hw_tok), dtype=torch.bool, device=pipe.device)
+
+        for j in range(nins):
+            for t in range(f_tok):
+                m = masks[j][t].convert("L")
+                # Resize to the token grid
+                m = m.resize((w_tok, h_tok), resample=Image.NEAREST)
+                arr = (np.array(m) > 127)
+                out[0, t, j] = torch.from_numpy(arr.reshape(-1)).to(device=pipe.device)
+
+        return out
+
+    def process(
+        self, pipe,
+        instance_prompts, instance_masks, instance_bboxes,
+        empty_instance_token, height, width, num_frames,
+        latents, saug_scale, saug_drop_prob,
+    ):
+        if instance_prompts is None or len(instance_prompts) == 0:
+            return {}
+
+        # f_lat: note this is the temporal length of the latents (NoiseInitializer uses length=(num_frames-1)//4+1)
+        f_lat = latents.shape[2]
+
+        # Target dtype (consistent with the DiT weights)
+        target_dtype = getattr(pipe, 'torch_dtype', torch.bfloat16)
+
+        # 1) instance prompt -> (1, f_lat, Nins, D_text) -> project to D_model
+        inst_vecs = [self._encode_one(pipe, p) for p in instance_prompts]  # list of (1, D_text)
+        inst_vecs = torch.cat(inst_vecs, dim=0).unsqueeze(0)  # (1, Nins, D_text)
+
+        # Use dit.text_embedding to project D_text (4096) down to D_model (1536)
+        # text_embedding expects (B, N, D_text) and outputs (B, N, D_model)
+        inst_vecs = inst_vecs.to(dtype=target_dtype)  # cast dtype
+        inst_vecs_proj = pipe.dit.text_embedding(inst_vecs)  # (1, Nins, D_model)
+ inst_tokens = inst_vecs_proj.unsqueeze(1).repeat(1, f_lat, 1, 1) # (1,F_lat,Nins,D_model) + + # 2) Empty-instance prompts (per the paper, plain empty strings make the embeddings collapse; dedicated special tokens are recommended instead) + # Approach: encode a distinct placeholder per instance so the embeddings are "semantically empty but mutually distinguishable" + nins = len(instance_prompts) + empty_vecs = [] + for j in range(nins): + tok = f"" + empty_vecs.append(self._encode_one(pipe, tok)) + empty_vecs = torch.cat(empty_vecs, dim=0).unsqueeze(0) # (1,Nins,D_text) + empty_vecs = empty_vecs.to(dtype=target_dtype) # cast dtype + empty_vecs_proj = pipe.dit.text_embedding(empty_vecs) # (1,Nins,D_model) + empty_tokens = empty_vecs_proj.unsqueeze(1).repeat(1, f_lat, 1, 1) # (1,F_lat,Nins,D_model) + + # 3) mask / bbox -> (1, F_tok, Nins, HW_tok) + if instance_masks is not None: + attn_mask = self._build_mask_from_masks(height, width, f_lat, instance_masks, pipe) + elif instance_bboxes is not None: + attn_mask = self._build_mask_from_bboxes(height, width, f_lat, instance_bboxes, pipe) + else: + # No spatial constraint: treat the whole frame as belonging to every instance (not recommended, but it runs) + ps_t, ps_h, ps_w = pipe.dit.patch_size + f_tok = f_lat // ps_t + h_lat = height // pipe.vae.upsampling_factor + w_lat = width // pipe.vae.upsampling_factor + h_tok = h_lat // ps_h + w_tok = w_lat // ps_w + hw_tok = h_tok * w_tok + attn_mask = torch.ones((1, f_tok, nins, hw_tok), dtype=torch.bool, device=pipe.device) + + # inst_tokens and empty_tokens already have the correct dtype (handled by text_embedding) + return { + "instance_prompt_tokens": inst_tokens, + "empty_instance_prompt_tokens": empty_tokens, + "instance_attn_mask": attn_mask, + "saug_scale": float(saug_scale), + "saug_drop_prob": float(saug_drop_prob), + } + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size,
tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), 
end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + reference_image = reference_image.resize((width, height)) + reference_latents = 
pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = 
torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("vap_clip_feature", 
"vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. 
prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return 
inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None and len(motion_video) > 0: + assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" + motion_latents = pipe.preprocess_video(motion_video) + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + pipe.load_models_to_device(["vae"]) + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), 
inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." + shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + + +class WanVideoUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), + output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") + ) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, 
"animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} + + +class WanVideoUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + output_params=("pose_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("animate_face_video",), + output_params=("face_pixel_values"), + ) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return inputs_shared, inputs_posi, inputs_nega + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, 
lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: 
torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + + # ===== InstanceV (new) ===== + instance_prompt_tokens: Optional[torch.Tensor] = None, # (1,F_lat,Nins,D_text) + empty_instance_prompt_tokens: Optional[torch.Tensor] = None, # (1,F_lat,Nins,D_text) + instance_attn_mask: Optional[torch.Tensor] = None, # (1,F_tok,Nins,HW_tok) + saug_scale: float = 0.0, + saug_drop_prob: float = 0.0, + _skip_saug: bool = False, + + + **kwargs, +): + if ( + not _skip_saug + and saug_scale + and instance_prompt_tokens is not None + and empty_instance_prompt_tokens is not None + and instance_attn_mask is not None + ): + common_args = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + vap=vap, + animate_adapter=animate_adapter, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + vap_hidden_state=vap_hidden_state, + vap_clip_feature=vap_clip_feature, + context_vap=context_vap, + drop_motion_frames=drop_motion_frames, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + pose_latents=pose_latents, + face_pixel_values=face_pixel_values, + longcat_latents=longcat_latents, + sliding_window_size=sliding_window_size, + sliding_window_stride=sliding_window_stride, + cfg_merge=cfg_merge, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + control_camera_latents_input=control_camera_latents_input, + fuse_vae_embedding_in_latents=fuse_vae_embedding_in_latents, + instance_attn_mask=instance_attn_mask, + **kwargs, + ) + noise_pred_cond = model_fn_wan_video( + instance_prompt_tokens=instance_prompt_tokens, + empty_instance_prompt_tokens=empty_instance_prompt_tokens, + saug_scale=0.0, + saug_drop_prob=saug_drop_prob, + _skip_saug=True, + **common_args, + ) + noise_pred_uncond = model_fn_wan_video( + instance_prompt_tokens=empty_instance_prompt_tokens, + empty_instance_prompt_tokens=empty_instance_prompt_tokens, + saug_scale=0.0, + saug_drop_prob=0.0, + _skip_saug=True, + **common_args, + ) + return apply_saug(noise_pred_cond, noise_pred_uncond, float(saug_scale)) + + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + 
clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # wan2.2 s2v + if audio_embeds is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + x = dit.patchify(x, control_camera_latents_input) + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + + # Patchify + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + # Reference 
image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + # Block + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + # ===== InstanceV (new) ===== + # STAPE and IMCA are handled inside the DiT block itself; here we only pass the arguments through + use_instancev = ( + instance_prompt_tokens is not None + and instance_attn_mask is not None + and hasattr(block, "imca") and block.imca is not None + ) + + # Build a forward function that carries the InstanceV arguments + def create_instancev_forward(module, use_iv, inst_tok, inst_mask, empty_inst_tok, saug_p): + def custom_forward(x, context, t_mod, freqs): + if use_iv: + return module( + x, context, t_mod, freqs, + instance_tokens=inst_tok, + instance_attn_mask=inst_mask, +
empty_instance_tokens=empty_inst_tok, + saug_drop_prob=saug_p, + ) + else: + return module(x, context, t_mod, freqs) + return custom_forward + + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_instancev_forward( + block, use_instancev, instance_prompt_tokens, + instance_attn_mask, empty_instance_prompt_tokens, saug_drop_prob + ), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_instancev_forward( + block, use_instancev, instance_prompt_tokens, + instance_attn_mask, empty_instance_prompt_tokens, saug_drop_prob + ), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if use_instancev: + x = block( + x, context, t_mod, freqs, + instance_tokens=instance_prompt_tokens, + instance_attn_mask=instance_attn_mask, + empty_instance_tokens=empty_instance_prompt_tokens, + saug_drop_prob=saug_drop_prob, + ) + else: + x = block(x, context, t_mod, freqs) + + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return output + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_embeds, + motion_latents, + s2v_pose_latents, + drop_motion_frames=True, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, +): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + origin_ref_latents = latents[:, :, 0:1] + 
x = latents[:, :, 1:] + + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) + + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel + + # reference image + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) + + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] + x = dit.head(x, t[:-1]) + x = dit.unpatchify(x, (f, h, w)) + # make compatible with wan video + x 
= torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/diffsynth/pipelines/wan_video_mvid.py b/diffsynth/pipelines/wan_video_mvid.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6f43a298a10e4187764e9578872aac8481726a --- /dev/null +++ b/diffsynth/pipelines/wan_video_mvid.py @@ -0,0 +1,1487 @@ +import torch, warnings, glob, os, types +import numpy as np +from PIL import Image +from einops import repeat, reduce +from typing import Optional, Union +from dataclasses import dataclass +from modelscope import snapshot_download +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +import torch.nn.functional as F +from PIL import Image, ImageOps + +from diffsynth.core import ModelConfig +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit, PipelineUnitRunner +from diffsynth.models import ModelManager, load_state_dict +from diffsynth.models.wan_video_dit_mvid import WanModel, RMSNorm, sinusoidal_embedding_1d +from diffsynth.models.wan_video_dit_s2v import rope_precompute +from diffsynth.models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm +from diffsynth.models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from diffsynth.models.wan_video_image_encoder import WanImageEncoder +from diffsynth.models.wan_video_vace import VaceWanModel +from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel +from diffsynth.schedulers.flow_match import FlowMatchScheduler +from diffsynth.prompters import WanPrompter +from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm +from diffsynth.lora import GeneralLoRALoader +from diffsynth.utils.data import save_video + + + +import random +from torchvision.transforms import Compose, Normalize, ToTensor + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.prompter = WanPrompter(tokenizer_path=tokenizer_path) + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") + self.unit_runner = PipelineUnitRunner() + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + # WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_VideoEmbedderFused(), + WanVideoUnit_RefEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + # WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + ] + + self.model_fn = model_fn_wan_video + + + def extrac_ref_latents(self, ref_images, vae, 
device, dtype, min_value=-1., max_value=1.): + # Load image. + ref_vae_latents = [] + for img in ref_images: + img = torch.Tensor(np.array(img, dtype=np.float32)) + img = img.to(dtype=dtype, device=device) + img = img * ((max_value - min_value) / 255) + min_value + img_vae_latent = vae.encode([img.permute(2,0,1).unsqueeze(1)], device=device) ###1 C 1 H W + ref_vae_latents.append(img_vae_latent) + return torch.cat(ref_vae_latents, dim=2) ###1 C ref_num H W + + + + def load_lora(self, module, path, alpha=1): + loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) + lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) + loader.load(module, lora, alpha=alpha) + + + def training_loss(self, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) + + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) + if inputs["ref_images_latents"] is not None: + if random.random() < inputs["args"].zero_face_ratio: + inputs["latents"] = torch.cat([inputs["latents"], torch.zeros_like(inputs['ref_images_latents'])], dim=2) + else: + inputs["latents"] = torch.cat([inputs["latents"], inputs['ref_images_latents']], dim=2) + training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) + # print(inputs["input_latents"].shape, inputs['ref_images_latents'].shape, inputs["num_ref_images"], training_target.shape) + noise_pred = self.model_fn(**inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float()[:, :, :-inputs["num_ref_images"]], training_target.float()) + loss = loss * self.scheduler.training_weight(timestep) + return loss + + + def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): + self.vram_management_enabled = True + if num_persistent_param_in_dit is not None: + vram_limit = None + else: + if vram_limit is None: + vram_limit = self.get_vram() + vram_limit = vram_limit - vram_buffer + if self.text_encoder is not None: + dtype = next(iter(self.text_encoder.parameters())).dtype + enable_vram_management( + self.text_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + T5RelativeEmbedding: AutoWrappedModule, + T5LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + if self.dit is not None: + dtype = next(iter(self.dit.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.Conv1d: AutoWrappedModule, + torch.nn.Embedding: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + 
max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + if self.dit2 is not None: + dtype = next(iter(self.dit2.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + if self.vae is not None: + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + RMS_norm: AutoWrappedModule, + CausalConv3d: AutoWrappedModule, + Upsample: AutoWrappedModule, + torch.nn.SiLU: AutoWrappedModule, + torch.nn.Dropout: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.image_encoder is not None: + dtype = next(iter(self.image_encoder.parameters())).dtype + enable_vram_management( + self.image_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.motion_controller is not None: + dtype = next(iter(self.motion_controller.parameters())).dtype + enable_vram_management( + self.motion_controller, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.vace is not None: + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.vace, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + + def initialize_usp(self): + import torch.distributed as dist + from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment + dist.init_process_group(backend="nccl", init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + 
ulysses_degree=dist.get_world_size(), + ) + torch.cuda.set_device(dist.get_rank()) + + + def enable_usp(self): + from xfuser.core.distributed import get_sequence_parallel_world_size + from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="/root/paddle_job/workspace/qizipeng/wanx_pretrainedmodels/Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp=False, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B", + "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B", + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P", + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern] + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: pipe.initialize_usp() + + # Download and load models + model_manager = ModelManager() + for model_config in model_configs: + model_config.download_if_necessary(use_usp=use_usp) + model_manager.load_model( + model_config.path, + device=model_config.offload_device or device, + torch_dtype=model_config.offload_dtype or torch_dtype + ) + + # Load models + pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") + dit = model_manager.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_manager.fetch_model("wan_video_vae") + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + tokenizer_config.download_if_necessary(use_usp=use_usp) + pipe.prompter.fetch_models(pipe.text_encoder) + # pipe.prompter.fetch_tokenizer(tokenizer_config.path) + pipe.prompter.fetch_tokenizer('/root/paddlejob/workspace/qizipeng/wanx_pretrainedmodels/Wan2.2-TI2V-5B/google/umt5-xxl') + + if audio_processor_config is not None: + audio_processor_config.download_if_necessary(use_usp=use_usp) + from transformers import Wav2Vec2Processor + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + return pipe + + + @torch.no_grad() + def __call__( + self, + args, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + input_pre_video: Optional[list[Image.Image]] = None, + ref_images: Optional[list[Image.Image]] = None, + prev_latent=None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[str] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_scale_face: Optional[float] = 5.0, #### face condition negetive + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # VAE tiling + tiled: Optional[bool] = True, 
+ tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + num_ref_images: Optional[int] = None, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "input_pre_video":input_pre_video, + "ref_images":ref_images, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, + "num_ref_images":num_ref_images, + "batch_size": 1 + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep) ## text img + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + # noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) ## O img + if 'ref_images_latents' in inputs_shared: + inputs_shared['latents'][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = torch.zeros_like(inputs_shared['ref_images_latents']) + noise_pred_nega_face = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep) # text, 0 + noise_all_eng = self.model_fn(args, **models, **inputs_shared, **inputs_nega, timestep=timestep) # 0, 0 + noise_pred = noise_all_eng + cfg_scale * 
(noise_pred_posi - noise_pred_nega_face) + cfg_scale_face * (noise_pred_nega_face - noise_all_eng) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + + if "ref_images_latents" in inputs_shared: + inputs_shared["latents"][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = inputs_shared["ref_images_latents"] + + # if progress_id in [0,10,20,30,40,43,44,45,46,47,48,49]: + # self.load_models_to_device(['vae']) + # video = self.vae.decode(inputs_shared["latents"][:, :, :-inputs_shared["ref_images_latents"].shape[2]], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + # video = self.vae_output_to_video(video) + # save_video(video, f"./results/videos/video_wyzlarge_arrange5_step_{timestep.item()}_progress_id_{progress_id}.mp4", fps=24, quality=5) + + # VACE (TODO: remove it) + if vace_reference_image is not None: + inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] + + # Decode + if "ref_images_latents" in inputs_shared: + inputs_shared["latents"] = inputs_shared["latents"][:, :, :-inputs_shared["ref_images_latents"].shape[2]] + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video, inputs_shared["latents"] + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "batch_size")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, batch_size = 1): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + length += 1 + + shape = (batch_size, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) ### B C F H W + # shape = (batch_size, vae.model.z_dim, length, height // vae.upsampling_factor, width // vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) + + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(["vae"]) + input_latents = [] + for input_video_ in input_video: + input_video_ = pipe.preprocess_video(input_video_) + input_latent_ = pipe.vae.encode(input_video_, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents.append(input_latent_) + 
input_latents = torch.cat(input_latents, dim = 0) ### B C F H W + # if vace_reference_image is not None: + # vace_reference_image = pipe.preprocess_video([vace_reference_image]) + # vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + # input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + # pipe.load_models_to_device(self.onload_model_names) + pipe.text_encoder = pipe.text_encoder.to(pipe.device) + prompt_emb_list = [] + for prompt_ in prompt: + prompt_emb_ = pipe.prompter.encode_prompt(prompt_, positive=positive, device=pipe.device) ###B C Token + prompt_emb_list.append(prompt_emb_) + prompt_emb = torch.cat(prompt_emb_list, dim = 0) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedder(PipelineUnit): + """ + Deprecated + """ + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or pipe.image_encoder is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context, "y": y} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + onload_model_names=("image_encoder",) + ) + + def 
process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + +class WanVideoUnit_VideoEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
+ """ + def __init__(self): + super().__init__( + input_params=("input_pre_video", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_pre_video, latents, height, width, tiled, tile_size, tile_stride): + if input_pre_video is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + # image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + # z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + input_pre_video = pipe.preprocess_video(input_pre_video) + input_pre_video_latent = pipe.vae.encode(input_pre_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + pre_t_num = input_pre_video_latent.shape[2] + latents[:, :, :pre_t_num] = input_pre_video_latent + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "prev_video_latents": input_pre_video_latent} + + +class WanVideoUnit_RefEmbedderFused(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("ref_images", "latents", "height", "width", "tiled", "tile_size", "tile_stride", "num_ref_images"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, ref_images, latents, height, width, tiled, tile_size, tile_stride, num_ref_images): + if ref_images is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + ref_images_latents = [] + for ref_images_ in ref_images: + ref_images_latent_ = pipe.extrac_ref_latents(ref_images_, pipe.vae, device=pipe.device, dtype=pipe.torch_dtype)[0][None] + ref_images_latents.append(ref_images_latent_) ##1 C ref_num H W + ref_images_latents = torch.concat(ref_images_latents, dim=0) + # r = num_ref_images - ref_images_latents.shape[2] + # ref_images_latents = F.pad(ref_images_latents, (0, 0, 0, 0, 0, r)) + latents = torch.concat([latents, ref_images_latents], dim=2) + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "ref_images_latents": ref_images_latents} + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(["vae"]) + reference_image = reference_image.resize((width, height)) + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, 
input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__(input_params=("motion_bucket_id",)) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = 
pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + vace_reference_image = pipe.preprocess_video([vace_reference_image]) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=()) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + +class WanVideoUnit_ShotEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("shot_cut_frames", "num_frames")) + + def process(self, pipe: WanVideoPipeline, shot_cut_frames, num_frames): + if shot_cut_frames is None: + return {} + + num_latent_frames = (num_frames - 1) // 4 + 1 + + # Convert frame cut indices to latent cut indices + shot_cut_latents = [0] + for frame_idx in sorted(shot_cut_frames): + if frame_idx > 0: + latent_idx = (frame_idx - 1) // 4 + 1 + if latent_idx < num_latent_frames: + shot_cut_latents.append(latent_idx) + + cuts = sorted(list(set(shot_cut_latents))) + [num_latent_frames] + + + shot_indices = torch.zeros(num_latent_frames, dtype=torch.long) + for i in range(len(cuts) - 1): + start_latent, end_latent = cuts[i],
cuts[i+1] + shot_indices[start_latent:end_latent] = i + + shot_indices = shot_indices.unsqueeze(0).to(device=pipe.device) + + return {"shot_indices": shot_indices} + + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return inputs_shared, inputs_posi, inputs_nega + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + args, + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: 
torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_input: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + pose_cond: Optional[torch.Tensor] = None, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + num_ref_images=None, + prev_video_latents: Optional[torch.Tensor] = None, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.ones((latents.shape[2] - num_ref_images, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep, + torch.zeros((num_ref_images, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Add camera control + x, (f, h, w) = dit.patchify(x, control_camera_latents_input) + + # 
Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + + + if args.shot_rope: + device = dit.shot_freqs[0].device + freq_s, freq_f, freq_h, freq_w = dit.shot_freqs # (end, dim_*/2) complex + shots_nums_batch = [ + [20, 20, 20, 3, 3], + [20, 20, 20, 3, 3], + ] + batch_freqs = [] # ⭐ one freqs tensor per sample + + for shots_nums in shots_nums_batch: # loop over batch + sample_freqs = [] # all shot freqs of the current sample + for shot_index, num_frames in enumerate(shots_nums): + f = num_frames + rope_s = freq_s[shot_index] \ + .view(1, 1, 1, -1) \ + .expand(f, h, w, -1) + + rope_f = freq_f[:f] \ + .view(f, 1, 1, -1) \ + .expand(f, h, w, -1) + + rope_h = freq_h[:h] \ + .view(1, h, 1, -1) \ + .expand(f, h, w, -1) + + rope_w = freq_w[:w] \ + .view(1, 1, w, -1) \ + .expand(f, h, w, -1) + + freqs = torch.cat( + [rope_s, rope_f, rope_h, rope_w], + dim=-1 + ) # (f, h, w, dim/2) complex + + freqs = freqs.reshape(f * h * w, 1, -1) + sample_freqs.append(freqs) + + # concatenate all shots within one sample + sample_freqs = torch.cat(sample_freqs, dim=0) # (N, 1, dim/2) + batch_freqs.append(sample_freqs) + + # ⭐ stack into a batch + batch_freqs = torch.stack(batch_freqs, dim=0).to(x.device) + # shape: (B, N, 1, dim/2) + + + if args.split_rope: + device = dit.freqs[0].device + freq_f, freq_h, freq_w = dit.freqs # precomputed 1D RoPE freqs + # ============================== + # 1) RoPE positions for the video tokens + # ============================== + f_video = torch.arange(f - num_ref_images, device=device) + h_video = torch.arange(h, device=device) + w_video = torch.arange(w, device=device) + + rope_f_video = freq_f[f_video].view(f - num_ref_images, 1, 1, -1).expand(f - num_ref_images, h, w, -1) + rope_h_video = freq_h[h_video].view(1, h, 1, -1).expand(f - num_ref_images, h, w, -1) + rope_w_video = freq_w[w_video].view(1, 1, w, -1).expand(f - num_ref_images, h, w, -1) + + rope_video = torch.cat([rope_f_video, rope_h_video, rope_w_video], dim=-1) + rope_video = rope_video.reshape((f - num_ref_images) * h * w, 1, -1).to(x.device) +
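+ # A sketch of the intent of this split-RoPE branch, using hypothetical numbers: if f = 24 latent frames include num_ref_images = 3 reference frames, the video tokens above take temporal positions 0..20, and offset = (24 - 3) + 10 = 31 below, so the reference tokens are moved to temporal positions 31, 32, 33 under split1/split2, or all to 31 under split3 (which hard-codes three refs via torch.tensor([0, 0, 0])). + # As far as the code shows, the goal is to keep reference-image tokens positionally separated from the generated video frames.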
# ============================== + # 2) RoPE positions for the reference images (all shifted by an offset) + # ============================== + # f axis: the refs occupy [offset ... offset + num_ref_images - 1] + + offset=f - num_ref_images + 10 + if args.split1: + # method 1: f, h, w all offset + f_ref = torch.arange(num_ref_images, device=device) + offset + # h/w are all shifted by offset + h_ref = torch.arange(h, device=device) + offset + w_ref = torch.arange(w, device=device) + offset + elif args.split2: + # method 2: f offset + f_ref = torch.arange(num_ref_images, device=device) + offset + # h/w keep their original positions + h_ref = torch.arange(h, device=device) + w_ref = torch.arange(w, device=device) + + elif args.split3: + # method 3: f offset but same h w offset + f_ref = torch.tensor([0, 0, 0], device=device) + offset + # h/w are all shifted by offset + h_ref = torch.arange(h, device=device) + offset + w_ref = torch.arange(w, device=device) + offset + + + rope_f_ref = freq_f[f_ref].view(num_ref_images, 1, 1, -1).expand(num_ref_images, h, w, -1) + rope_h_ref = freq_h[h_ref].view(1, h, 1, -1).expand(num_ref_images, h, w, -1) + rope_w_ref = freq_w[w_ref].view(1, 1, w, -1).expand(num_ref_images, h, w, -1) + + rope_ref = torch.cat([rope_f_ref, rope_h_ref, rope_w_ref], dim=-1) + rope_ref = rope_ref.reshape(num_ref_images * h * w, 1, -1).to(x.device) + + # ============================== + # 3) Concatenate video and ref-image freqs + # ============================== + freqs = torch.cat([rope_video, rope_ref], dim=0) + + + + else: + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs) + + ## Build an attention mask so that, during cross attention, each video token can only attend to the text tokens of its own shot; all other text tokens are forcibly masked out. + use_attn_mask = True + if use_attn_mask: + # Illustrative format of shot_ranges (it is overwritten below from text_cut_positions): + # shot_ranges = [ + # (s0, e0), # text span of shot 0 + # (s1, e1), # text span of shot 1 + # ] + try: + B, S_q = x.shape[0], x.shape[1] + L_text_ctx = context.shape[1] + + shot_ranges = text_cut_positions['shots'] + S_shots = len(shot_ranges) + + device, dtype = x.device, x.dtype + + # -------------------------------------------------- + # 1. Build shot_table: (S_shots, L_text_ctx) + # -------------------------------------------------- + shot_table = torch.zeros( + S_shots, L_text_ctx, + dtype=torch.bool, + device=device + ) + + for sid, (s0, s1) in enumerate(shot_ranges): + s0 = int(s0) + s1 = int(s1) + shot_table[sid, s0:s1 + 1] = True + + # -------------------------------------------------- + # 2. video token -> shot id + # shot_indices: (B, T) + # expand to (B, T*h*w) = (B, S_q) + # shot_indices holds, for every video token, the index of the shot it belongs to + # -------------------------------------------------- + vid_shot = shot_indices.repeat_interleave(h * w, dim=1) + + # sanity check (strongly recommended to keep) + max_shot_id = int(vid_shot.max()) + assert max_shot_id < S_shots, \ + f"shot index out of bounds: max={max_shot_id}, S_shots={S_shots}" + + # -------------------------------------------------- + # 3. allow mask: (B, S_q, L_text_ctx) + # -------------------------------------------------- + allow = shot_table[vid_shot] +
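+ # Worked example of the intended masking, with hypothetical spans: if two shots have text spans (0, 9) and (10, 19), shot_table[0] is True on text tokens 0..9 and shot_table[1] on tokens 10..19; a video token whose shot id is 1 then gets allow = shot_table[1], so the bias built below stays 0 on its own shot's text tokens and becomes -1e4 everywhere else. + # Note that text_cut_positions and shot_indices are expected to come from the surrounding scope; they are not defined in this function, and the try/except below will surface an error if they are missing.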
# -------------------------------------------------- + # 4. Build the attention bias + # -------------------------------------------------- + block_value = -1e4 + bias = torch.zeros( + B, S_q, L_text_ctx, + dtype=dtype, + device=device + ) + bias = bias.masked_fill(~allow, block_value) + + # attn_mask shape: (B, 1, S_q, L_text_ctx) + attn_mask = bias.unsqueeze(1) + + except Exception as e: + print("!!!!!! ERROR FOUND IN SHOT ATTENTION MASK !!!!!!!") + raise e + else: + attn_mask = None + + use_sparse_self_attn = getattr(dit, 'use_sparse_self_attn', False) + if use_sparse_self_attn: + shot_latent_indices = shot_indices.repeat_interleave(h * w, dim=1) + shot_latent_indices = labels_to_cuts(shot_latent_indices) + else: + shot_latent_indices = None + + + + + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def labels_to_cuts(batch_labels: torch.Tensor): + + assert batch_labels.dim() == 2, "expect [b, s]" + b, s = batch_labels.shape + labs = batch_labels.to(torch.long) + + + diffs = torch.zeros((b, s), dtype=torch.bool, device=labs.device) + diffs[:, 1:] = labs[:, 1:] != labs[:, :-1] + + cuts_list = [] + for i in range(b): + + change_pos = torch.nonzero(diffs[i], as_tuple=False).flatten() + cuts = [0] + cuts.extend(change_pos.tolist()) + if cuts[-1] != s: + cuts.append(s) + + cuts_list.append(cuts) + return cuts_list diff --git a/diffsynth/pipelines/wan_video_statemachine copy.py b/diffsynth/pipelines/wan_video_statemachine copy.py new file mode 100644 index 0000000000000000000000000000000000000000..ea03f0f7ced2013c1c6d158d3c1ea24f236d1010 --- /dev/null +++ b/diffsynth/pipelines/wan_video_statemachine copy.py @@ -0,0 +1,1595 @@ +import torch, types
+import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import Wav2Vec2Processor + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel +from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel +from ..models.wav2vec import WanS2VAudioEncoder +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.audio_processor: Wav2Vec2Processor = None + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = 
get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + use_siglip_image_encoder: bool = False, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp() + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_pool.fetch_model("wan_video_vae") + if use_siglip_image_encoder: + pipe.image_encoder = model_pool.fetch_model("siglip2_image_encoder") + else: + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + 
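# A minimal usage sketch of this constructor (illustrative only; the DiT file pattern
+ # below is an assumption and must match the layout of the checkpoint actually loaded):
+ #
+ #   pipe = WanVideoPipeline.from_pretrained(
+ #       torch_dtype=torch.bfloat16,
+ #       device="cuda",
+ #       model_configs=[
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+ #       ],
+ #   )
+ #
+ # The .pth patterns above are redirected to their safetensors equivalents when
+ # redirect_common_files=True (see redirect_dict).
+
+ 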
# Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Instance StateMachine (optional) + # These tensors are forwarded to the DiT model_fn if supported by the loaded DiT. + # Expected shapes: + # instance_class_ids: (B, N) long + # instance_state_ids: (B, N) long + # instance_ids: (B, N) long + # instance_masks: (B, N, T, H, W) float/bool (pixel/latent space), will be downsampled to patch grid + instance_class_ids: Optional[torch.Tensor] = None, + instance_state_ids: Optional[torch.Tensor] = None, + instance_ids: Optional[torch.Tensor] = None, + instance_masks: Optional[torch.Tensor] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": 
prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "instance_class_ids": instance_class_ids, + "instance_state_ids": instance_state_ids, + "instance_ids": instance_ids, + "instance_masks": instance_masks, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + 
f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def 
encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
+ """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + reference_image = reference_image.resize((width, height)) + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, 
camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = 
pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("vap_clip_feature", "vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. 
encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": 
"tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None and len(motion_video) > 0: + assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" + motion_latents = pipe.preprocess_video(motion_video) + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + 
pipe.load_models_to_device(["vae"]) + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." 
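+ # Helper for long-audio inference: it normalizes the requested resolution and frame
+ # count, computes per-inference audio embeddings (one chunk per num_frames - 1 video
+ # frames), and pre-computes the matching pose latents for each chunk. It returns
+ # (audio_embeds, pose_latents, num_chunks), with pose_latents set to None when no
+ # pose video is supplied.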
+ shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + + +class WanVideoUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), + output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") + ) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} + + +class WanVideoUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + output_params=("pose_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("animate_face_video",), + output_params=("face_pixel_values"), + ) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return inputs_shared, inputs_posi, inputs_nega + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class 
WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 
1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, 
:, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + # Instance StateMachine (optional) + instance_class_ids: Optional[torch.Tensor] = None, + instance_state_ids: Optional[torch.Tensor] = None, + instance_ids: Optional[torch.Tensor] = None, + instance_masks: Optional[torch.Tensor] = None, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + instance_class_ids=instance_class_ids, + instance_state_ids=instance_state_ids, + instance_ids=instance_ids, + instance_masks=instance_masks, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y", "instance_masks"], + batch_size=2 if cfg_merge else 1 + ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # wan2.2 s2v + if audio_embeds is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and 
fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + if instance_class_ids is not None and instance_class_ids.shape[0] != context.shape[0]: + instance_class_ids = torch.concat([instance_class_ids] * context.shape[0], dim=0) + if instance_state_ids is not None and instance_state_ids.shape[0] != context.shape[0]: + instance_state_ids = torch.concat([instance_state_ids] * context.shape[0], dim=0) + if instance_ids is not None and instance_ids.shape[0] != context.shape[0]: + instance_ids = torch.concat([instance_ids] * context.shape[0], dim=0) + if instance_masks is not None and instance_masks.shape[0] != context.shape[0]: + instance_masks = torch.concat([instance_masks] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + patchified = dit.patchify(x, control_camera_latents_input) + motion_vec = None + + # Animate + # patchified could be: + # - legacy: (B, C, F, H, W) + # - statemachine_1: (tokens, grid_size) where tokens is (B, L, C) and grid_size is (F, H, W) + if isinstance(patchified, (tuple, list)) and len(patchified) == 2: + x, grid_size = patchified + f, h, w = grid_size + else: + x = patchified + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + # Instance StateMachine: build instance tokens and mask in patch-token space (B, N, L) + inst_tokens, inst_mask_flat = None, None + if instance_class_ids is not None and instance_state_ids is not None and instance_ids is not None and instance_masks is not None: + if hasattr(dit, "instance_encoder"): + inst_tokens = dit.instance_encoder(instance_class_ids, instance_state_ids, instance_ids) + if hasattr(dit, "process_masks"): + inst_mask_flat = dit.process_masks(instance_masks, (f, h, w)) + else: + # Fallback: downsample to (f,h,w) and flatten + inst_down = 
torch.nn.functional.interpolate(instance_masks.to(dtype=torch.float32), size=(f, h, w), mode="nearest") + inst_mask_flat = rearrange(inst_down, "b n f h w -> b n (f h w)").clamp_(0.0, 1.0) + + # Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + # Pad instance masks for the reference token prefix so shapes match (B, N, L_total) + if inst_mask_flat is not None: + ref_len = reference_latents.shape[1] + inst_mask_flat = torch.cat( + [torch.zeros((inst_mask_flat.shape[0], inst_mask_flat.shape[1], ref_len), device=inst_mask_flat.device, dtype=inst_mask_flat.dtype), inst_mask_flat], + dim=2, + ) + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + if inst_mask_flat is not None: + mask_chunks = torch.chunk(inst_mask_flat, get_sequence_parallel_world_size(), dim=2) + mask_chunks = [ + torch.nn.functional.pad(chunk, (0, mask_chunks[0].shape[2] - chunk.shape[2]), value=0) + for chunk in mask_chunks + ] + inst_mask_flat = mask_chunks[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_with_instance(module, instance_tokens, instance_masks): + def custom_forward(*inputs): + return module(*inputs, instance_tokens=instance_tokens, instance_masks=instance_masks) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + # Block + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with 
torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + use_instance = inst_tokens is not None and inst_mask_flat is not None + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block, inst_tokens, inst_mask_flat) if use_instance else create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block, inst_tokens, inst_mask_flat) if use_instance else create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if inst_tokens is not None and inst_mask_flat is not None: + x = block(x, context, t_mod, freqs, instance_tokens=inst_tokens, instance_masks=inst_mask_flat) + else: + x = block(x, context, t_mod, freqs) + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + + # Animate + if motion_vec is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return output + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_embeds, + motion_latents, + s2v_pose_latents, + drop_motion_frames=True, + 
use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, +): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) + + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel + + # reference image + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) + + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, 
x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] + x = dit.head(x, t[:-1]) + x = dit.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/diffsynth/pipelines/wan_video_statemachine.py b/diffsynth/pipelines/wan_video_statemachine.py new file mode 100644 index 0000000000000000000000000000000000000000..e310692464a168dc67c27c68d0a137c81de89f74 --- /dev/null +++ b/diffsynth/pipelines/wan_video_statemachine.py @@ -0,0 +1,1617 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import Wav2Vec2Processor + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel +from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel +from ..models.wav2vec import WanS2VAudioEncoder +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.audio_processor: Wav2Vec2Processor = None + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_InstanceStateTextEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + 
WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + use_siglip_image_encoder: bool = False, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp() + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_pool.fetch_model("wan_video_vae") + if use_siglip_image_encoder: + pipe.image_encoder = model_pool.fetch_model("siglip2_image_encoder") + else: + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Instance (optional, bbox-based only) + # instance_ids: Tensor (B, N) or (N,) + # instance_class_text: List[str] or str (tag per instance) + # instance_state_texts: List[List[str]] or List[str] (state labels per instance; each instance has S states) + # instance_state_weights: Tensor/list (B, N, F, S) or (N, F, S) weights per frame + # instance_bboxes: Tensor (B, N, F, 4) or (N, F, 4) xyxy in pixel coords + instance_ids: Optional[torch.Tensor] = None, + instance_class_text: Optional[list[str] | str] = None, + instance_state_texts: Optional[list[list[str]] | list[str]] = None, + instance_state_weights: Optional[torch.Tensor | list] = 
None, + instance_bboxes: Optional[torch.Tensor] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + ): + def _to_tensor(x, dtype=None): + if x is None: + return None + if not isinstance(x, torch.Tensor): + x = torch.as_tensor(x) + if dtype is not None: + x = x.to(dtype=dtype) + return x.to(device=self.device) + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inst_bboxes = _to_tensor(instance_bboxes, dtype=torch.float32) + if inst_bboxes is not None and inst_bboxes.ndim == 3: + inst_bboxes = inst_bboxes.unsqueeze(0) + if inst_bboxes is not None and inst_bboxes.ndim == 2 and inst_bboxes.shape[-1] == 4: + inst_bboxes = inst_bboxes.unsqueeze(0) + inst_ids = _to_tensor(instance_ids, dtype=torch.long) + if inst_ids is not None and inst_ids.ndim == 1: + inst_ids = inst_ids.unsqueeze(0) + inst_state_weights = _to_tensor(instance_state_weights, dtype=torch.float32) + if inst_state_weights is not None: + if inst_state_weights.ndim == 3: + inst_state_weights = inst_state_weights.unsqueeze(0) + elif inst_state_weights.ndim == 2: + inst_state_weights = inst_state_weights.unsqueeze(0).unsqueeze(0) + + use_instance = any( + v is not None + for v in (inst_ids, instance_class_text, instance_state_texts, inst_state_weights, inst_bboxes) + ) + if use_instance: + if inst_ids is None or inst_bboxes is None or instance_class_text is None or instance_state_texts is None or inst_state_weights is None: + raise ValueError( + "When using instance control, please provide: instance_ids, instance_class_text, instance_state_texts, instance_state_weights, instance_bboxes." 
+ ) + if inst_state_weights.shape[2] != inst_bboxes.shape[2]: + raise ValueError( + f"instance_state_weights and instance_bboxes must have the same frame length on dim=2, got {int(inst_state_weights.shape[2])} vs {int(inst_bboxes.shape[2])}." + ) + + # Inputs + inputs_posi = { + "prompt": prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "instance_ids": inst_ids, + "instance_bboxes": inst_bboxes, + "instance_state_texts": instance_state_texts, + "instance_class_text": instance_class_text, + "instance_state_weights": inst_state_weights, + "instance_state_text_embeds_multi": None, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Keep model_fn inputs minimal for instance control + inputs_shared.pop("instance_class_text", None) + inputs_shared.pop("instance_state_texts", None) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - 
noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: 
+ return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class WanVideoUnit_InstanceStateTextEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=( + "instance_ids", + "instance_class_text", + "instance_state_texts", + "instance_state_weights", + "instance_bboxes", + ), + output_params=("instance_state_text_embeds_multi",), + onload_model_names=("text_encoder",), + ) + + @staticmethod + def _as_list(x): + if x is None: + return None + if isinstance(x, list): + return x + return [x] + + @staticmethod + def _as_nested_list(x): + if x is None: + return None + if isinstance(x, (str, bytes)): + return [[x]] + if not isinstance(x, list): + return [[x]] + if len(x) == 0: + return [[]] + if isinstance(x[0], (list, tuple)): + return [list(v) for v in x] + return [list(x)] + + def _encode_phrases(self, pipe: WanVideoPipeline, phrases: list[str]) -> torch.Tensor: + if pipe.tokenizer is None or pipe.text_encoder is None: + raise ValueError("tokenizer/text_encoder is not initialized; cannot encode instance state texts.") + pipe.load_models_to_device(self.onload_model_names) + ids, mask = pipe.tokenizer(phrases, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + text_tokens = pipe.text_encoder(ids, mask) # (N, L, text_dim) + for i, v in enumerate(seq_lens): + text_tokens[i, v:] = 0 + denom = mask.to(dtype=text_tokens.dtype).sum(dim=1, keepdim=True).clamp(min=1) + pooled = (text_tokens * mask.to(dtype=text_tokens.dtype).unsqueeze(-1)).sum(dim=1) / denom + return pooled.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) # (1, N, text_dim) + + def process( + self, + pipe: WanVideoPipeline, + instance_ids, + instance_class_text, + instance_state_texts, + instance_state_weights, + instance_bboxes, + ) -> dict: + class_texts = self._as_list(instance_class_text) + multi_state_texts = self._as_nested_list(instance_state_texts) + if class_texts is None and multi_state_texts is None: + return {} + + if class_texts is None: + raise ValueError("instance_class_text is required when using instance control.") + if multi_state_texts is None: + raise ValueError("instance_state_texts is required when using instance control.") + if instance_ids is None or instance_bboxes is None: + raise ValueError("instance_ids and instance_bboxes are required for instance control.") + + if isinstance(instance_ids, torch.Tensor): + if instance_ids.ndim != 2: + raise 
ValueError(f"instance_ids must be (B,N), got {tuple(instance_ids.shape)}") + num_instances = int(instance_ids.shape[1]) + if len(class_texts) != num_instances: + raise ValueError(f"instance_class_text length {len(class_texts)} != N {num_instances} from instance_ids") + if isinstance(instance_bboxes, torch.Tensor) and instance_bboxes.ndim == 4 and int(instance_bboxes.shape[1]) != num_instances: + raise ValueError(f"instance_bboxes N {int(instance_bboxes.shape[1])} != N {num_instances} from instance_ids") + + if len(multi_state_texts) != len(class_texts): + raise ValueError(f"instance_state_texts length {len(multi_state_texts)} != instance_class_text length {len(class_texts)}") + if len(multi_state_texts) == 0: + return {} + + state_counts = [len(states) for states in multi_state_texts] + if len(set(state_counts)) != 1: + raise ValueError("Each instance must have the same number of states in instance_state_texts.") + num_states = state_counts[0] + if num_states == 0: + return {} + + if instance_state_weights is not None: + # Expect (B,N,F,S) or (N,F,S)/(B,N,S); enforce last dim matches S. + s_dim = int(getattr(instance_state_weights, "shape", [None])[-1]) + if s_dim != num_states: + raise ValueError(f"instance_state_weights last dim {s_dim} != num_states {num_states}") + + phrases = [] + for cls, states in zip(class_texts, multi_state_texts): + phrases.extend([f"{cls} is {s}" for s in states]) + embeds_flat = self._encode_phrases(pipe, phrases) # (1, N*S, text_dim) + embeds_multi = embeds_flat.view(1, len(class_texts), num_states, embeds_flat.shape[-1]) + return {"instance_state_text_embeds_multi": embeds_multi} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, 
width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + reference_image = reference_image.resize((width, height)) + 
reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = 
torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("vap_clip_feature", 
"vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. 
prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return 
inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None and len(motion_video) > 0: + assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}" + motion_latents = pipe.preprocess_video(motion_video) + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + pipe.load_models_to_device(["vae"]) + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), 
inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." + shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + + +class WanVideoUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), + output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") + ) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, 
"animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} + + +class WanVideoUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + output_params=("pose_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("animate_face_video",), + output_params=("face_pixel_values"), + ) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return inputs_shared, inputs_posi, inputs_nega + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, 
lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + sliced = {} + for tensor_name in tensor_names: + val = tensor_dict[tensor_name] + if val.ndim == 5: + # (B,C,T,H,W) + sliced_val = val[:, :, t:t_, :, :] + elif val.ndim == 4: + # (B,N,T,D) e.g., instance_bboxes/weights + sliced_val = val[:, :, t:t_, :] + elif val.ndim == 3: + # (B,N,T) + sliced_val = val[:, :, t:t_] + else: + raise ValueError(f"TemporalTiler_BCTHW only supports 3D/4D/5D tensors with time dim, got {tensor_name} shape {tuple(val.shape)}") + sliced[tensor_name] = sliced_val.to(device=computation_device, dtype=computation_dtype) + model_kwargs.update(sliced) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + 
).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + # Instance (optional, bbox-based only) + instance_ids: Optional[torch.Tensor] = None, # (B,N) + instance_state_text_embeds_multi: Optional[torch.Tensor] = None, # (B,N,S,text_dim) + instance_state_weights: Optional[torch.Tensor] = None, # (B,N,F,S) weights per frame + instance_bboxes: Optional[torch.Tensor] = None, # (B,N,F,4) + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + if any(v is not None for v in (instance_ids, instance_state_text_embeds_multi, instance_state_weights, instance_bboxes)): + raise ValueError("sliding_window_size/stride is not supported together with instance control (bbox/state weights).") + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + vap=vap, + animate_adapter=animate_adapter, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + instance_ids=instance_ids, + instance_state_text_embeds_multi=instance_state_text_embeds_multi, + instance_state_weights=instance_state_weights, + instance_bboxes=instance_bboxes, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = 
dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Instance (align batch for CFG) + if instance_ids is not None and instance_ids.shape[0] != context.shape[0]: + instance_ids = torch.concat([instance_ids] * context.shape[0], dim=0) + if instance_bboxes is not None and instance_bboxes.shape[0] != context.shape[0]: + instance_bboxes = torch.concat([instance_bboxes] * context.shape[0], dim=0) + if instance_state_text_embeds_multi is not None and instance_state_text_embeds_multi.shape[0] != context.shape[0]: + instance_state_text_embeds_multi = torch.concat([instance_state_text_embeds_multi] * context.shape[0], dim=0) + if instance_state_weights is not None and instance_state_weights.shape[0] != context.shape[0]: + instance_state_weights = torch.concat([instance_state_weights] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + patchified = dit.patchify(x, control_camera_latents_input) + motion_vec = None + + # Animate + # patchified could be: + # - legacy: (B, C, F, H, W) + # - statemachine_1: (tokens, grid_size) where tokens is (B, L, C) and grid_size is (F, H, W) + if isinstance(patchified, (tuple, list)) and len(patchified) == 2: + x, grid_size = patchified + f, h, w = grid_size + else: + x = patchified + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + # Instance: build instance tokens and bbox->patch mask in patch-token space (B, N, L) + inst_tokens, inst_mask_flat = None, None + orig_frames, orig_h, orig_w = latents.shape[2:] + if ( + instance_ids is not None + and instance_bboxes is not None + and instance_state_text_embeds_multi is not None + and instance_state_weights is not None + and hasattr(dit, "instance_encoder") + ): + inst_tokens = dit.instance_encoder( + instance_ids=instance_ids, + state_text_embeds_multi=instance_state_text_embeds_multi, + state_weights=instance_state_weights, + num_time_patches=f, + ) + H_img = int(kwargs.get("height", orig_h)) + W_img = int(kwargs.get("width", orig_w)) + F_img = int(instance_bboxes.shape[2]) if instance_bboxes.ndim == 4 else int(orig_frames) + inst_mask_flat = dit.process_masks( + grid_size=(f, h, w), 
+ image_size=(F_img, H_img, W_img), + bboxes=instance_bboxes, + bbox_mask=None, + ) + + # Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + if inst_mask_flat is not None: + ref_len = reference_latents.shape[1] + inst_mask_flat = torch.cat( + [torch.zeros((inst_mask_flat.shape[0], inst_mask_flat.shape[1], ref_len), device=inst_mask_flat.device, dtype=inst_mask_flat.dtype), inst_mask_flat], + dim=2, + ) + if inst_tokens is not None: + inst_tokens = torch.cat( + [ + torch.zeros((inst_tokens.shape[0], 1, inst_tokens.shape[2], inst_tokens.shape[3]), device=inst_tokens.device, dtype=inst_tokens.dtype), + inst_tokens, + ], + dim=1, + ) + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + full_seq_len = x.shape[1] + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + if inst_mask_flat is not None: + mask_chunks = torch.chunk(inst_mask_flat, get_sequence_parallel_world_size(), dim=2) + mask_chunks = [ + torch.nn.functional.pad(chunk, (0, mask_chunks[0].shape[2] - chunk.shape[2]), value=0) + for chunk in mask_chunks + ] + inst_mask_flat = mask_chunks[get_sequence_parallel_rank()] + if inst_tokens is not None: + chunk_len = x.shape[1] + hw = int(h * w) + offset = int(get_sequence_parallel_rank() * chunk_len) + global_pos = torch.arange(chunk_len, device=x.device) + offset + valid = global_pos < full_seq_len + time_index = (global_pos // max(hw, 1)).clamp(max=inst_tokens.shape[1] - 1).long() + inst_tokens = inst_tokens[:, time_index] * valid.view(1, chunk_len, 1, 1).to(dtype=inst_tokens.dtype) + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_with_instance(module, instance_tokens, instance_masks): + def custom_forward(*inputs): + return module(*inputs, instance_tokens=instance_tokens, instance_masks=instance_masks) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + # Block + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + 
use_reentrant=False, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + use_instance = inst_tokens is not None and inst_mask_flat is not None + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block, inst_tokens, inst_mask_flat) if use_instance else create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_with_instance(block, inst_tokens, inst_mask_flat) if use_instance else create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + if inst_tokens is not None and inst_mask_flat is not None: + x = block(x, context, t_mod, freqs, instance_tokens=inst_tokens, instance_masks=inst_mask_flat) + else: + x = block(x, context, t_mod, freqs) + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + + # Animate + if motion_vec is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f87254f35015f87924c64ac5241f257f6b150995 --- /dev/null +++ b/diffsynth/pipelines/z_image.py @@ -0,0 +1,257 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple + +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput + +from transformers import AutoTokenizer +from ..models.z_image_text_encoder import ZImageTextEncoder +from ..models.z_image_dit import ZImageDiT +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder + + +class ZImagePipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Z-Image") + self.text_encoder: ZImageTextEncoder = None + self.dit: ZImageDiT = None + self.vae_encoder: FluxVAEEncoder = None + 
self.vae_decoder: FluxVAEDecoder = None + self.tokenizer: AutoTokenizer = None + self.in_iteration_models = ("dit",) + self.units = [ + ZImageUnit_ShapeChecker(), + ZImageUnit_PromptEmbedder(), + ZImageUnit_NoiseInitializer(), + ZImageUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_z_image + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("z_image_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 8, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class ZImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: ZImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class ZImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + 
super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = pipe.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def process(self, pipe: ZImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) + return {"prompt_embeds": prompt_embeds} + + +class ZImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: ZImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class ZImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae_encoder(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_z_image( + dit: ZImageDiT, + latents=None, + timestep=None, + prompt_embeds=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + latents = [rearrange(latents, "B C H W -> C B H W")] + timestep = (1000 - timestep) / 1000 + model_output = dit( + latents, + timestep, + prompt_embeds, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + )[0][0] + model_output = -model_output + model_output = rearrange(model_output, "C B H W -> B C H W") + return model_output diff --git a/diffsynth/utils/blip2_state_text_encoder.py b/diffsynth/utils/blip2_state_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..18d66f90161529778f7df47a446c6a80e7a76704 --- /dev/null +++ b/diffsynth/utils/blip2_state_text_encoder.py @@
-0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import lru_cache +from typing import Iterable, List + +import torch + + +@dataclass +class Blip2TextStateEncoderConfig: + model_name_or_path: str = "Salesforce/blip2-itm-vit-g" + device: str = "cpu" + torch_dtype: torch.dtype = torch.float16 + max_length: int = 32 + + +class Blip2TextStateEncoder: + """ + Encode state text into a single vector (text_embeds) with BLIP2's `Blip2TextModelWithProjection`. + + Design goals: + - States are stored in the data as readable strings (e.g. "raw", "cooked") + - At training/inference time these strings are turned into state_features: (B,N,D_text) + - The downstream InstanceFeatureExtractor then projects D_text to the DiT hidden_dim + """ + + def __init__(self, cfg: Blip2TextStateEncoderConfig): + self.cfg = cfg + self._tokenizer = None + self._model = None + + def _lazy_init(self): + if self._model is not None: + return + from transformers import AutoTokenizer, Blip2TextModelWithProjection + + self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path) + self._model = Blip2TextModelWithProjection.from_pretrained( + self.cfg.model_name_or_path, + torch_dtype=self.cfg.torch_dtype, + ) + self._model.eval() + for p in self._model.parameters(): + p.requires_grad_(False) + self._model.to(device=self.cfg.device) + + @torch.inference_mode() + def encode_texts(self, texts: List[str]) -> torch.Tensor: + self._lazy_init() + tok = self._tokenizer( + texts, + padding=True, + truncation=True, + max_length=self.cfg.max_length, + return_tensors="pt", + ) + tok = {k: v.to(self.cfg.device) for k, v in tok.items()} + out = self._model(**tok) + # (B, D_text) + return out.text_embeds.to(dtype=torch.float32, device="cpu") + + +def encode_state_text_tensor( + state_texts: list, + model_name_or_path: str = "Salesforce/blip2-itm-vit-g", + device: str = "cpu", + torch_dtype: torch.dtype = torch.float16, + max_length: int = 32, +) -> torch.Tensor: + """ + Encode a nested list of state_texts (B,N) into a tensor: (B,N,D_text), float32 on CPU. + """ + if not isinstance(state_texts, list) or not state_texts: + raise ValueError("state_texts must be a non-empty nested list (B,N)") + if not isinstance(state_texts[0], list): + raise ValueError("state_texts must be nested list like [[...], [...]]") + + encoder = Blip2TextStateEncoder( + Blip2TextStateEncoderConfig( + model_name_or_path=model_name_or_path, + device=device, + torch_dtype=torch_dtype, + max_length=max_length, + ) + ) + + # flatten unique texts to avoid redundant encode + all_texts = [] + for row in state_texts: + for t in row: + if not isinstance(t, str): + raise ValueError(f"state_text must be str, got: {type(t)}") + all_texts.append(t) + uniq = sorted(set(all_texts)) + emb = encoder.encode_texts(uniq) # (U, D) + table = {t: emb[i] for i, t in enumerate(uniq)} + + b = len(state_texts) + n = len(state_texts[0]) + out = torch.stack([torch.stack([table[t] for t in row], dim=0) for row in state_texts], dim=0) + # (B,N,D) + if out.shape[0] != b or out.shape[1] != n: + raise RuntimeError(f"unexpected encoded shape: {tuple(out.shape)}") + return out + diff --git a/diffsynth/utils/controlnet/__init__.py b/diffsynth/utils/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df23b6c61b99319f1f41d6448b0c52ffd03b9f25 --- /dev/null +++ b/diffsynth/utils/controlnet/__init__.py @@ -0,0 +1,2 @@ +from .controlnet_input import ControlNetInput +from .annotator import Annotator diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py new file mode 100644 index
0000000000000000000000000000000000000000..06553e06d1c6d09f5a3deecfd4ea5604c5dd4352 --- /dev/null +++ b/diffsynth/utils/controlnet/annotator.py @@ -0,0 +1,62 @@ +from typing_extensions import Literal, TypeAlias + + +Processor_id: TypeAlias = Literal[ + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" +] + +class Annotator: + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False): + if not skip_processor: + if processor_id == "canny": + from controlnet_aux.processor import CannyDetector + self.processor = CannyDetector() + elif processor_id == "depth": + from controlnet_aux.processor import MidasDetector + self.processor = MidasDetector.from_pretrained(model_path).to(device) + elif processor_id == "softedge": + from controlnet_aux.processor import HEDdetector + self.processor = HEDdetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart": + from controlnet_aux.processor import LineartDetector + self.processor = LineartDetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart_anime": + from controlnet_aux.processor import LineartAnimeDetector + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) + elif processor_id == "openpose": + from controlnet_aux.processor import OpenposeDetector + self.processor = OpenposeDetector.from_pretrained(model_path).to(device) + elif processor_id == "normal": + from controlnet_aux.processor import NormalBaeDetector + self.processor = NormalBaeDetector.from_pretrained(model_path).to(device) + elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") + else: + self.processor = None + + self.processor_id = processor_id + self.detect_resolution = detect_resolution + + def to(self,device): + if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"): + + self.processor.model.to(device) + + def __call__(self, image, mask=None): + width, height = image.size + if self.processor_id == "openpose": + kwargs = { + "include_body": True, + "include_hand": True, + "include_face": True + } + else: + kwargs = {} + if self.processor is not None: + detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) + image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) + image = image.resize((width, height)) + return image + diff --git a/diffsynth/utils/controlnet/controlnet_input.py b/diffsynth/utils/controlnet/controlnet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2949bc5fab87c4779f5f4259943aa245faa4c4 --- /dev/null +++ b/diffsynth/utils/controlnet/controlnet_input.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from PIL import Image + + +@dataclass +class ControlNetInput: + controlnet_id: int = 0 + scale: float = 1.0 + start: float = 1.0 + end: float = 0.0 + image: Image.Image = None + inpaint_mask: Image.Image = None + processor_id: str = None diff --git a/diffsynth/utils/data/__init__.py b/diffsynth/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b9daa41bea9d36e012d52a1d280d1cf8d92850 --- /dev/null +++ b/diffsynth/utils/data/__init__.py @@ -0,0 +1,217 @@ +import imageio, os +import numpy as np +from PIL import Image +from tqdm import tqdm +import 
subprocess +import shutil + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return Image.open(self.file_list[item]).convert("RGB") + + def __del__(self): + pass + + +def crop_and_resize(image, height, width): + image = np.array(image) + image_height, image_width, _ = image.shape + if image_height / image_width < height / width: + croped_width = int(image_height / height * width) + left = (image_width - croped_width) // 2 + image = image[:, left: left+croped_width] + image = Image.fromarray(image).resize((width, height)) + else: + croped_height = int(image_width / width * height) + left = (image_height - croped_height) // 2 + image = image[left: left+croped_height, :] + image = Image.fromarray(image).resize((width, height)) + return image + + +class VideoData: + def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.set_shape(height, width) + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + width, height = frame.size + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = crop_and_resize(frame, self.height, self.width) + return frame + + def __del__(self): + pass + + def save_images(self, folder): + os.makedirs(folder, exist_ok=True) + for i in tqdm(range(self.__len__()), desc="Saving images"): + frame = self.__getitem__(i) + 
frame.save(os.path.join(folder, f"{i}.png")) + + +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + +def save_frames(frames, save_path): + os.makedirs(save_path, exist_ok=True) + for i, frame in enumerate(tqdm(frames, desc="Saving images")): + frame.save(os.path.join(save_path, f"{i}.png")) + + +def merge_video_audio(video_path: str, audio_path: str): + # TODO: may need a in-python implementation to avoid subprocess dependency + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. + + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + + +def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None): + save_video(frames, save_path, fps, quality, ffmpeg_params) + merge_video_audio(save_path, audio_path) diff --git a/diffsynth/utils/lora/__init__.py b/diffsynth/utils/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb5901acba99ed8490079b8ebaeb6991ae3f59d --- /dev/null +++ b/diffsynth/utils/lora/__init__.py @@ -0,0 +1,3 @@ +from .general import GeneralLoRALoader +from .merge import merge_lora +from .reset_rank import reset_lora_rank \ No newline at end of file diff --git a/diffsynth/utils/lora/flux.py b/diffsynth/utils/lora/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..502c5fd449d619160aba29d33d3915ece1763cda --- /dev/null +++ b/diffsynth/utils/lora/flux.py @@ -0,0 +1,204 @@ +from .general import GeneralLoRALoader +import torch, math + + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight", + 
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight", + 
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": 
"single_blocks.blockid.to_qkv_mlp.lora_B.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight", + } + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().fuse_lora_to_base_model(model, state_dict_lora, alpha) + + def convert_state_dict(self, state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ diff --git a/diffsynth/utils/lora/general.py b/diffsynth/utils/lora/general.py new file mode 100644 index 0000000000000000000000000000000000000000..624549d518fb8f2a43b04625b268cbab4441a21a --- /dev/null +++ b/diffsynth/utils/lora/general.py @@ -0,0 +1,62 @@ +import torch + + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_up." 
in key: + lora_A_key = "lora_down" + lora_B_key = "lora_up" + else: + lora_A_key = "lora_A" + lora_B_key = "lora_B" + if lora_B_key not in key: + continue + keys = key.split(".") + if len(keys) > keys.index(lora_B_key) + 2: + keys.pop(keys.index(lora_B_key) + 1) + keys.pop(keys.index(lora_B_key)) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key)) + return lora_name_dict + + + def convert_state_dict(self, state_dict, suffix=".weight"): + name_dict = self.get_name_dict(state_dict) + state_dict_ = {} + for name in name_dict: + weight_up = state_dict[name_dict[name][0]] + weight_down = state_dict[name_dict[name][1]] + state_dict_[name + f".lora_B{suffix}"] = weight_up + state_dict_[name + f".lora_A{suffix}"] = weight_down + return state_dict_ + + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict, alpha=1.0): + updated_num = 0 + state_dict = self.convert_state_dict(state_dict) + lora_layer_names = set([i.replace(".lora_B.weight", "") for i in state_dict if i.endswith(".lora_B.weight")]) + for name, module in model.named_modules(): + if name in lora_layer_names: + weight_up = state_dict[name + ".lora_B.weight"].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict[name + ".lora_A.weight"].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict_base = module.state_dict() + state_dict_base["weight"] = state_dict_base["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict_base) + updated_num += 1 + print(f"{updated_num} tensors are fused by LoRA. Fused LoRA layers cannot be cleared by `pipe.clear_lora()`.") diff --git a/diffsynth/utils/lora/merge.py b/diffsynth/utils/lora/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..61904ff4bcebc6c344c23f26073aec292355217c --- /dev/null +++ b/diffsynth/utils/lora/merge.py @@ -0,0 +1,20 @@ +import torch +from typing import Dict, List + + +def merge_lora_weight(tensors_A, tensors_B): + lora_A = torch.concat(tensors_A, dim=0) + lora_B = torch.concat(tensors_B, dim=1) + return lora_A, lora_B + + +def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1): + lora_merged = {} + keys = [i for i in loras[0].keys() if ".lora_A." 
in i] + for key in keys: + tensors_A = [lora[key] for lora in loras] + tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras] + lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B) + lora_merged[key] = lora_A * alpha + lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B + return lora_merged diff --git a/diffsynth/utils/lora/reset_rank.py b/diffsynth/utils/lora/reset_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..9522b043ff962bc050fa79596197f00abf3877b0 --- /dev/null +++ b/diffsynth/utils/lora/reset_rank.py @@ -0,0 +1,20 @@ +import torch + +def decomposite(tensor_A, tensor_B, rank): + dtype, device = tensor_A.dtype, tensor_A.device + weight = tensor_B @ tensor_A + U, S, V = torch.pca_lowrank(weight.float(), q=rank) + tensor_A = (V.T).to(dtype=dtype, device=device).contiguous() + tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous() + return tensor_A, tensor_B + +def reset_lora_rank(lora, rank): + lora_merged = {} + keys = [i for i in lora.keys() if ".lora_A." in i] + for key in keys: + tensor_A = lora[key] + tensor_B = lora[key.replace(".lora_A.", ".lora_B.")] + tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank) + lora_merged[key] = tensor_A + lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B + return lora_merged \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/__init__.py b/diffsynth/utils/state_dict_converters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffsynth/utils/state_dict_converters/flux2_text_encoder.py b/diffsynth/utils/state_dict_converters/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0975e62a35021c697192ad054f0e3aff42289292 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux2_text_encoder.py @@ -0,0 +1,17 @@ +def Flux2TextEncoderStateDictConverter(state_dict): + rename_dict = { + "multi_modal_projector.linear_1.weight": "model.multi_modal_projector.linear_1.weight", + "multi_modal_projector.linear_2.weight": "model.multi_modal_projector.linear_2.weight", + "multi_modal_projector.norm.weight": "model.multi_modal_projector.norm.weight", + "multi_modal_projector.patch_merger.merging_layer.weight": "model.multi_modal_projector.patch_merger.merging_layer.weight", + "language_model.lm_head.weight": "lm_head.weight", + } + state_dict_ = {} + for k in state_dict: + k_ = k + k_ = k_.replace("language_model.model", "model.language_model") + k_ = k_.replace("vision_tower", "model.vision_tower") + if k_ in rename_dict: + k_ = rename_dict[k_] + state_dict_[k_] = state_dict[k] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_controlnet.py b/diffsynth/utils/state_dict_converters/flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..15f9447d22bc0ebc2dbb3d2eac8dbf0bd78e4151 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_controlnet.py @@ -0,0 +1,103 @@ +import torch + + +def FluxControlNetStateDictConverter(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": 
"guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_dit.py b/diffsynth/utils/state_dict_converters/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbe460e2554d4be71a5e55ff2b38f000c4cc041 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_dit.py @@ -0,0 +1,92 @@ +import torch + + +def FluxDiTStateDictConverter(state_dict): + is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0 + if is_nexus_gen: + dit_state_dict = {} + for key in state_dict: + if key.startswith('pipe.dit.'): + param = state_dict[key] + new_key = key.replace("pipe.dit.", "") + if new_key.startswith("final_norm_out.linear."): + param = torch.concat([param[3072:], param[:3072]], dim=0) + dit_state_dict[new_key] = param + return dit_state_dict + + rename_dict = { + "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", + "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", + "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias", + "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight", + "txt_in.bias": "context_embedder.bias", + "txt_in.weight": "context_embedder.weight", + "vector_in.in_layer.bias": "pooled_text_embedder.0.bias", + "vector_in.in_layer.weight": "pooled_text_embedder.0.weight", + "vector_in.out_layer.bias": "pooled_text_embedder.2.bias", + "vector_in.out_layer.weight": "pooled_text_embedder.2.weight", + "final_layer.linear.bias": "final_proj_out.bias", + "final_layer.linear.weight": "final_proj_out.weight", + "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias", + "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight", + "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias", + "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight", + "img_in.bias": "x_embedder.bias", + "img_in.weight": "x_embedder.weight", + "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias", + } + suffix_rename_dict = { + "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight", + "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight", + "img_attn.proj.bias": "attn.a_to_out.bias", + "img_attn.proj.weight": "attn.a_to_out.weight", + "img_attn.qkv.bias": "attn.a_to_qkv.bias", + "img_attn.qkv.weight": "attn.a_to_qkv.weight", + "img_mlp.0.bias": "ff_a.0.bias", + "img_mlp.0.weight": "ff_a.0.weight", + "img_mlp.2.bias": "ff_a.2.bias", + "img_mlp.2.weight": "ff_a.2.weight", + "img_mod.lin.bias": "norm1_a.linear.bias", + "img_mod.lin.weight": "norm1_a.linear.weight", + "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight", + "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight", + "txt_attn.proj.bias": "attn.b_to_out.bias", + "txt_attn.proj.weight": "attn.b_to_out.weight", + "txt_attn.qkv.bias": 
"attn.b_to_qkv.bias", + "txt_attn.qkv.weight": "attn.b_to_qkv.weight", + "txt_mlp.0.bias": "ff_b.0.bias", + "txt_mlp.0.weight": "ff_b.0.weight", + "txt_mlp.2.bias": "ff_b.2.bias", + "txt_mlp.2.weight": "ff_b.2.weight", + "txt_mod.lin.bias": "norm1_b.linear.bias", + "txt_mod.lin.weight": "norm1_b.linear.weight", + + "linear1.bias": "to_qkv_mlp.bias", + "linear1.weight": "to_qkv_mlp.weight", + "linear2.bias": "proj_out.bias", + "linear2.weight": "proj_out.weight", + "modulation.lin.bias": "norm.linear.bias", + "modulation.lin.weight": "norm.linear.weight", + "norm.key_norm.scale": "norm_k_a.weight", + "norm.query_norm.scale": "norm_q_a.weight", + } + state_dict_ = {} + for name in state_dict: + original_name = name + if name.startswith("model.diffusion_model."): + name = name[len("model.diffusion_model."):] + names = name.split(".") + if name in rename_dict: + rename = rename_dict[name] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "double_blocks": + rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "single_blocks": + if ".".join(names[2:]) in suffix_rename_dict: + rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + else: + pass + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_infiniteyou.py b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..7025b392d54c5b4844ed3b3387bd010217897f4a --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py @@ -0,0 +1,2 @@ +def FluxInfiniteYouImageProjectorStateDictConverter(state_dict): + return state_dict['image_proj'] \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_ipadapter.py b/diffsynth/utils/state_dict_converters/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..86dfb133655fbe9c33c84b419706a103cec96b1b --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_ipadapter.py @@ -0,0 +1,32 @@ +def FluxIpAdapterStateDictConverter(state_dict): + state_dict_ = {} + + if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict): + for name, param in state_dict["ip_adapter"].items(): + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = param + + if "image_proj" in state_dict: + for name, param in state_dict["image_proj"].items(): + name_ = "image_proj." 
+ name + state_dict_[name_] = param + return state_dict_ + + for key, value in state_dict.items(): + if key.startswith("image_proj."): + state_dict_[key] = value + elif key.startswith("ip_adapter."): + new_key = key.replace("ip_adapter.", "ipadapter_modules.") + state_dict_[new_key] = value + else: + pass + + return state_dict_ + + +def SiglipStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if key.startswith("vision_model."): + new_state_dict[key] = state_dict[key] + return new_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..aa018aa5c570cc67f4856002e8f1f83f18998e07 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py @@ -0,0 +1,31 @@ +def FluxTextEncoderClipStateDictConverter(state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias", + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..d35eb831d2a7b1d48eee747d251d6cfb6ad508ef --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py @@ -0,0 +1,4 @@ +def FluxTextEncoderT5StateDictConverter(state_dict): + state_dict_ = {i: state_dict[i] for i in state_dict} + state_dict_["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_vae.py b/diffsynth/utils/state_dict_converters/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6547f18f1e1cfe69d0cf4ef43860702812d25fab --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_vae.py @@ -0,0 +1,382 @@ +def FluxVAEEncoderStateDictConverter(state_dict): + rename_dict = { + "encoder.conv_in.bias": "conv_in.bias", + "encoder.conv_in.weight": "conv_in.weight", + "encoder.conv_out.bias": "conv_out.bias", + "encoder.conv_out.weight": "conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", + 
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", + "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", + "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", + "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", + "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", + "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", + "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", + "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", + "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", + "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", + "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", + "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", + "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", + 
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", + "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", + "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", + "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", + "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", + "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", + "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", + "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", + "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", + "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", + "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", + "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", + "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", + "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", + "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", + "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", + "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", + "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", + "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", + "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", + "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", + "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", + "encoder.norm_out.bias": "conv_norm_out.bias", + "encoder.norm_out.weight": "conv_norm_out.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEDecoderStateDictConverter(state_dict): + rename_dict = { + "decoder.conv_in.bias": "conv_in.bias", + "decoder.conv_in.weight": "conv_in.weight", + "decoder.conv_out.bias": "conv_out.bias", + "decoder.conv_out.weight": "conv_out.weight", + "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias", + "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight", + "decoder.mid.attn_1.proj_out.bias": 
"blocks.1.transformer_blocks.0.to_out.bias", + "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight", + "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias", + "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight", + "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias", + "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight", + "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias", + "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight", + "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias", + "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight", + "decoder.norm_out.bias": "conv_norm_out.bias", + "decoder.norm_out.weight": "conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias", + "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias", + "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight", + "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias", + 
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias", + "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight", + "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight", + "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias", + "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight", + "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight", + "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight", + "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias", + "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight", + "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias", + "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight", + "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias", + "decoder.up.3.block.0.norm2.weight": 
"blocks.3.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight", + "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight", + "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight", + "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias", + "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEEncoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "quant_conv": "quant_conv", + "encoder.conv_in": "conv_in", + "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm", + "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q", + "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k", + "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v", + "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out", + "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1", + "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1", + "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2", + "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2", + "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1", + "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1", + "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2", + "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2", + "encoder.conv_norm_out": "conv_norm_out", + "encoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("encoder.down_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ + + +def FluxVAEDecoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "post_quant_conv": "post_quant_conv", + "decoder.conv_in": "conv_in", + "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm", + "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q", + "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k", + "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v", + "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out", + "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1", + "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1", + "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2", + "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2", + "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1", + "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1", + "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2", + "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2", + "decoder.conv_norm_out": "conv_norm_out", + "decoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("decoder.up_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/nexus_gen.py b/diffsynth/utils/state_dict_converters/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..aff853d0e76dd1f130ce44241462f61af84370db --- /dev/null +++ b/diffsynth/utils/state_dict_converters/nexus_gen.py @@ -0,0 +1,6 @@ +def NexusGenAutoregressiveModelStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + value = state_dict[key] + new_state_dict["model." + key] = value + return new_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/nexus_gen_projector.py b/diffsynth/utils/state_dict_converters/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a44665551ba4a97d063de94f6025c8c48989fd --- /dev/null +++ b/diffsynth/utils/state_dict_converters/nexus_gen_projector.py @@ -0,0 +1,15 @@ +def NexusGenMergerStateDictConverter(state_dict): + merger_state_dict = {} + for key in state_dict: + if key.startswith('embedding_merger.'): + value = state_dict[key] + new_key = key.replace("embedding_merger.", "") + merger_state_dict[new_key] = value + return merger_state_dict + +def NexusGenAdapterStateDictConverter(state_dict): + adapter_state_dict = {} + for key in state_dict: + if key.startswith('adapter.'): + adapter_state_dict[key] = state_dict[key] + return adapter_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py b/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e8192a1f2a959685cf1fa5af40824bd896454141 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py @@ -0,0 +1,10 @@ +def QwenImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for k in state_dict: + v = state_dict[k] + if k.startswith("visual."): + k = "model." 
+ k + elif k.startswith("model."): + k = k.replace("model.", "model.language_model.") + state_dict_[k] = v + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/step1x_connector.py b/diffsynth/utils/state_dict_converters/step1x_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..35a2a4167b16ea5cc16aaa1b0f20575bc2918bbf --- /dev/null +++ b/diffsynth/utils/state_dict_converters/step1x_connector.py @@ -0,0 +1,7 @@ +def Qwen2ConnectorStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("connector."): + name_ = name[len("connector."):] + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py b/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea69f4e6696bbef6de197abaa031ea8cc5b398e --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py @@ -0,0 +1,6 @@ +def WanAnimateAdapterStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"): + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_dit.py b/diffsynth/utils/state_dict_converters/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c7716dad52e42ebf76f98dd85511ac0a04b3d3b3 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_dit.py @@ -0,0 +1,83 @@ +def WanVideoDiTFromDiffusers(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": 
"blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = state_dict[name] + return state_dict_ + + +def WanVideoDiTStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vace"): + continue + if name.split(".")[0] in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]: + continue + name_ = name + if name_.startswith("model."): + name_ = name_[len("model."):] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py b/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb7e9bfce50e88601f8876341ac56645a8e5913 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py @@ -0,0 +1,8 @@ +def WanImageEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("textual."): + continue + name_ = "model." 
+ name + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_mot.py b/diffsynth/utils/state_dict_converters/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..12b42d7db752fca1cb24c0f16217deab925916f5 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_mot.py @@ -0,0 +1,78 @@ +def WanVideoMotStateDictConverter(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + 
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36) + mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)} + state_dict_ = {} + for name in state_dict: + if "_mot_ref" not in name: + continue + param = state_dict[name] + name = name.replace("_mot_ref", "") + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + if name.split(".")[1].isdigit(): + block_id = int(name.split(".")[1]) + name = name.replace(str(block_id), str(mot_layers_mapping[block_id])) + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/wan_video_vace.py b/diffsynth/utils/state_dict_converters/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfef6998f47ac7d3640b28b99f109e7f04baeba --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_vace.py @@ -0,0 +1,3 @@ +def VaceWanModelDictConverter(state_dict): + state_dict_ = {name: state_dict[name] for name in state_dict if name.startswith("vace")} + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/wan_video_vae.py b/diffsynth/utils/state_dict_converters/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..76a430e1bd4575e0ae06234de23b620d4877566f --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_vae.py @@ -0,0 +1,7 @@ +def WanVideoVAEStateDictConverter(state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py b/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa12c0d4ff7fc166eea1f804cb645be9aa28776 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py @@ -0,0 +1,12 @@ +def WanS2VAudioEncoderStateDictConverter(state_dict): + rename_dict = { + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_g": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0", + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_v": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1", + } + state_dict_ = {} + for name in state_dict: + name_ = "model." 
+ name + if name_ in rename_dict: + name_ = rename_dict[name_] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/xfuser/__init__.py b/diffsynth/utils/xfuser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13dd178e2d47bf58de1bfca6d21052d0563c70ca --- /dev/null +++ b/diffsynth/utils/xfuser/__init__.py @@ -0,0 +1 @@ +from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..11733132f95dea86e6fbe34900de40c40b92ba60 --- /dev/null +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -0,0 +1,145 @@ +import torch +from typing import Optional +from einops import rearrange +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + + +def initialize_usp(): + import torch.distributed as dist + from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment + dist.init_process_group(backend="nccl", init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), + ) + torch.cuda.set_device(dist.get_rank()) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def usp_dit_forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 
1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # Context Parallel + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + + # unpatchify + x = self.unpatchify(x, (f, h, w)) + return x + + +def usp_attn_forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + + x = xFuserLongContextAttention()( + None, + query=q, + key=k, + value=v, + ) + x = x.flatten(2) + + del q, k, v + torch.cuda.empty_cache() + return self.o(x) \ No newline at end of file diff --git a/diffsynth_ext/__init__.py b/diffsynth_ext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..def289369bedc66c7022820c0ad6dc30bdf7fd0c --- /dev/null +++ b/diffsynth_ext/__init__.py @@ -0,0 +1 @@ +# Extension modules live here. This package does not modify base code. 
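The xfuser helpers above are only exported here; this diff does not show how they are attached to a pipeline. A minimal wiring sketch follows, assuming a WanVideoPipeline whose `dit` exposes `blocks[i].self_attn` with `q`/`k`/`v`/`o` projections and a `num_heads` attribute (the same names `usp_attn_forward` reads), launched with one process per GPU, e.g. `torchrun --nproc_per_node=8 run_usp.py`. The pipeline construction and the `pipe.dit` attribute are assumptions, not part of this diff.

import types
import torch
from diffsynth.utils.xfuser import initialize_usp, usp_dit_forward, usp_attn_forward

def enable_usp(dit: torch.nn.Module) -> None:
    # Rebind the context-parallel forwards onto an existing Wan DiT instance so that
    # each rank attends only over its own chunk of the patchified sequence.
    dit.forward = types.MethodType(usp_dit_forward, dit)
    for block in dit.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)

if __name__ == "__main__":
    # One call per process: NCCL init + xFuser groups (ulysses_degree == world size, ring_degree == 1).
    initialize_usp()
    # pipe = WanVideoPipeline.from_pretrained(...)  # hypothetical; loading is not shown in this diff
    # enable_usp(pipe.dit)
    # video = pipe(prompt="...", seed=0)

Inside `usp_dit_forward`, the patchified sequence is chunked along dim 1 across ranks, the trailing chunk is zero-padded to the common length, and `get_sp_group().all_gather` reassembles the full sequence (dropping the padding) before `unpatchify`.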
diff --git a/diffsynth_ext/comp_attn.py b/diffsynth_ext/comp_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..78cbdfa3abea6fb71db4789e36e308dd85c7c594 --- /dev/null +++ b/diffsynth_ext/comp_attn.py @@ -0,0 +1,469 @@ +import math +from dataclasses import dataclass +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F + +from diffsynth.diffusion.base_pipeline import PipelineUnit +from diffsynth.pipelines.wan_video import ( + WanVideoPipeline, + WanVideoUnit_PromptEmbedder, + WanVideoUnit_CfgMerger, +) + + +@dataclass +class CompAttnConfig: + subjects: Sequence[str] + bboxes: Optional[Sequence] = None + enable_sci: bool = True + enable_lam: bool = True + temperature: float = 0.2 + apply_to_negative: bool = False + interpolate: bool = False + + +def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]: + if subject_ids.numel() == 0 or valid_len <= 0: + return [] + prompt_slice = prompt_ids[:valid_len].tolist() + subject_list = subject_ids.tolist() + span = len(subject_list) + if span > valid_len: + return [] + for start in range(valid_len - span + 1): + if prompt_slice[start:start + span] == subject_list: + return list(range(start, start + span)) + return [] + + +def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor: + mask = torch.zeros((len(indices_list), seq_len), dtype=torch.bool) + for i, indices in enumerate(indices_list): + if not indices: + continue + mask[i, torch.tensor(indices, dtype=torch.long)] = True + return mask + + +def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor: + prompt_norm = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + 1e-8) + anchor_norm = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + 1e-8) + cosine = torch.matmul(prompt_norm, anchor_norm.transpose(0, 1)) + scores = torch.exp(cosine / tau) + diag = scores.diagonal() + denom = scores.sum(dim=1).clamp(min=1e-8) + return diag / denom + + +def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor: + total = anchor_vecs.sum(dim=0, keepdim=True) + return anchor_vecs * anchor_vecs.shape[0] - total + + +def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor: + if state is None or not state.get("enable_sci", False): + return context + subject_mask = state.get("subject_token_mask") + delta = state.get("delta") + saliency = state.get("saliency") + if subject_mask is None or delta is None or saliency is None: + return context + if subject_mask.numel() == 0: + return context + t_scale = float(state.get("timestep_scale", 1000.0)) + t_value = float(timestep.reshape(-1)[0].item()) + t_ratio = max(0.0, min(1.0, t_value / t_scale)) + omega = 1.0 - t_ratio + delta = delta.to(device=context.device, dtype=context.dtype) + saliency = saliency.to(device=context.device, dtype=context.dtype) + scale = omega * (1.0 - saliency).unsqueeze(-1) + delta = delta * scale + mask = subject_mask.to(device=context.device) + token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta) + apply_mask = state.get("apply_mask") + if apply_mask is not None: + apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1) + else: + apply_mask = 1.0 + return context + token_delta.unsqueeze(0) * apply_mask + + +def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor: + if bboxes.shape[2] == target_frames: + return bboxes + b, m, f, _ = 
bboxes.shape + coords = bboxes.reshape(b * m, f, 4).transpose(1, 2) + coords = F.interpolate(coords, size=target_frames, mode="linear", align_corners=True) + coords = coords.transpose(1, 2).reshape(b, m, target_frames, 4) + return coords + + +def build_layout_mask_from_bboxes( + bboxes: torch.Tensor, + grid_size: tuple[int, int, int], + image_size: tuple[int, int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + if bboxes is None: + return None + bboxes = bboxes.to(device=device, dtype=dtype) + b, m, f_layout, _ = bboxes.shape + f_grid, h_grid, w_grid = grid_size + height, width = image_size + layout = torch.zeros((b, m, f_grid, h_grid, w_grid), device=device, dtype=dtype) + for bi in range(b): + for mi in range(m): + for ti in range(f_layout): + pt = int(ti * f_grid / max(1, f_layout)) + pt = max(0, min(f_grid - 1, pt)) + x0, y0, x1, y1 = bboxes[bi, mi, ti] + x0 = float(x0) + y0 = float(y0) + x1 = float(x1) + y1 = float(y1) + if x1 <= x0 or y1 <= y0: + continue + px0 = int(math.floor(x0 / max(1.0, width) * w_grid)) + px1 = int(math.ceil(x1 / max(1.0, width) * w_grid)) + py0 = int(math.floor(y0 / max(1.0, height) * h_grid)) + py1 = int(math.ceil(y1 / max(1.0, height) * h_grid)) + px0 = max(0, min(w_grid, px0)) + px1 = max(0, min(w_grid, px1)) + py0 = max(0, min(h_grid, py0)) + py1 = max(0, min(h_grid, py1)) + if px1 <= px0 or py1 <= py0: + continue + layout[bi, mi, pt, py0:py1, px0:px1] = 1.0 + return layout.flatten(2) + + +def lam_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_heads: int, + state: dict, +) -> Optional[torch.Tensor]: + subject_mask = state.get("subject_token_mask_lam") or state.get("subject_token_mask") + layout_mask = state.get("layout_mask") + if subject_mask is None or layout_mask is None: + return None + if subject_mask.numel() == 0 or layout_mask.numel() == 0: + return None + b, q_len, dim = q.shape + _, k_len, _ = k.shape + if layout_mask.shape[-1] != q_len: + return None + if subject_mask.shape[-1] != k_len: + return None + head_dim = dim // num_heads + qh = q.view(b, q_len, num_heads, head_dim).transpose(1, 2) + kh = k.view(b, k_len, num_heads, head_dim).transpose(1, 2) + vh = v.view(b, k_len, num_heads, head_dim).transpose(1, 2) + attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim) + attn_max = attn_scores.max(dim=-1, keepdim=True).values + attn_min = attn_scores.min(dim=-1, keepdim=True).values + g_plus = attn_max - attn_scores + g_minus = attn_min - attn_scores + subject_mask = subject_mask.to(device=attn_scores.device) + layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype) + apply_mask = state.get("apply_mask") + if apply_mask is not None: + layout_mask = layout_mask * apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1) + subject_any = subject_mask.any(dim=0) + bias = torch.zeros_like(attn_scores) + for k_idx in range(subject_mask.shape[0]): + mask_k = subject_mask[k_idx] + if not mask_k.any(): + continue + mask_other = subject_any & (~mask_k) + mask_k = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len) + mask_other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len) + g_k = g_plus * mask_k + g_minus * mask_other + attn_k = attn_scores[..., subject_mask[k_idx]].mean(dim=-1).mean(dim=1) + adapt_mask = attn_k >= attn_k.mean(dim=-1, keepdim=True) + layout_k = layout_mask[:, k_idx] + adapt_f = adapt_mask.to(layout_k.dtype) + inter = (adapt_f * layout_k).sum(dim=-1) + union = (adapt_f + layout_k - adapt_f * 
layout_k).sum(dim=-1) + iou = inter / union.clamp(min=1e-6) + strength = (1.0 - iou).view(b, 1, 1, 1) + bias = bias + g_k * strength * layout_k.view(b, 1, q_len, 1) + attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype) + out = torch.matmul(attn_probs, vh) + out = out.transpose(1, 2).reshape(b, q_len, dim) + return out + + +class CompAttnUnit(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "context": "context"}, + input_params_nega={"prompt": "negative_prompt", "context": "context"}, + output_params=("comp_attn_state",), + onload_model_names=("text_encoder",), + ) + + def _clean_text(self, pipe: WanVideoPipeline, text: str) -> str: + if getattr(pipe.tokenizer, "clean", None): + return pipe.tokenizer._clean(text) + return text + + def _tokenize_subject(self, pipe: WanVideoPipeline, text: str) -> torch.Tensor: + text = self._clean_text(pipe, text) + tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt") + return tokens["input_ids"][0] + + def _normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32) + if bboxes.dim() == 2 and bboxes.shape[-1] == 4: + bboxes = bboxes.unsqueeze(0).unsqueeze(0) + elif bboxes.dim() == 3 and bboxes.shape[-1] == 4: + bboxes = bboxes.unsqueeze(0) + elif bboxes.dim() != 4 or bboxes.shape[-1] != 4: + raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}") + return bboxes + + def process(self, pipe: WanVideoPipeline, prompt, context) -> dict: + config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None) + if context is None or prompt is None or config is None: + return {} + if not config.subjects: + return {} + negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None) + if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt: + return {} + pipe.load_models_to_device(self.onload_model_names) + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + prompt_ids = ids[0] + valid_len = int(mask[0].sum().item()) + indices_list = [] + valid_subjects = [] + for idx, subject in enumerate(config.subjects): + subject_ids = self._tokenize_subject(pipe, subject) + indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len) + if not indices: + print(f"Comp-Attn: subject tokens not found in prompt: {subject}") + continue + indices_list.append(indices) + valid_subjects.append(idx) + if not indices_list: + return {} + subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device) + mask_float = subject_token_mask.to(dtype=context.dtype) + denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1) + prompt_vecs = (mask_float @ context[0]) / denom + anchor_vecs = [] + for idx in valid_subjects: + subject = config.subjects[idx] + sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True) + sub_ids = sub_ids.to(pipe.device) + sub_mask = sub_mask.to(pipe.device) + emb = pipe.text_encoder(sub_ids, sub_mask) + pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1) + anchor_vecs.append(pooled) + anchor_vecs = torch.cat(anchor_vecs, dim=0) + saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype) + delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype)) + bboxes = None + if config.bboxes is not None: + bboxes = 
self._normalize_bboxes(config.bboxes) + if bboxes.shape[1] >= len(config.subjects): + bboxes = bboxes[:, valid_subjects] + if bboxes.shape[1] != len(valid_subjects): + print("Comp-Attn: bboxes subject count mismatch, disable LAM") + bboxes = None + if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None: + bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames)) + state = { + "enable_sci": bool(config.enable_sci), + "enable_lam": bool(config.enable_lam) and bboxes is not None, + "subject_token_mask": subject_token_mask, + "saliency": saliency, + "delta": delta, + "layout_bboxes": bboxes, + "timestep_scale": 1000.0, + "apply_to_negative": bool(config.apply_to_negative), + } + if negative_prompt and prompt == negative_prompt: + pipe._comp_attn_state_neg = state + else: + pipe._comp_attn_state_pos = state + return {"comp_attn_state": state} + + +class CompAttnMergeUnit(PipelineUnit): + def __init__(self): + super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",)) + + def process(self, pipe: WanVideoPipeline, cfg_merge) -> dict: + if not cfg_merge: + return {} + state_pos = getattr(pipe, "_comp_attn_state_pos", None) + state_neg = getattr(pipe, "_comp_attn_state_neg", None) + merged = state_pos or state_neg + if merged is None: + return {} + merged = dict(merged) + apply_to_negative = bool(merged.get("apply_to_negative", False)) + merged["apply_mask"] = torch.tensor([1.0, 1.0 if apply_to_negative else 0.0]) + return {"comp_attn_state": merged} + + +def _patch_cross_attention(pipe: WanVideoPipeline): + for block in pipe.dit.blocks: + cross_attn = block.cross_attn + if getattr(cross_attn, "_comp_attn_patched", False): + continue + orig_forward = cross_attn.forward + + def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe): + state = getattr(_pipe, "_comp_attn_runtime_state", None) + if state is None or not state.get("enable_lam", False): + return _orig(x, y) + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + lam_out = lam_attention(q, k, v, self.num_heads, state) + if lam_out is None: + out = self.attn(q, k, v) + else: + out = lam_out + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + img_out = self.attn(q, k_img, v_img) + out = out + img_out + return self.o(out) + + cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__) + cross_attn._comp_attn_patched = True + + +def _get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]: + f = latents.shape[2] // patch_size[0] + h = latents.shape[3] // patch_size[1] + w = latents.shape[4] // patch_size[2] + return f, h, w + + +def _wrap_model_fn(pipe: WanVideoPipeline): + if getattr(pipe, "_comp_attn_model_fn_patched", False): + return + orig_model_fn = pipe.model_fn + + def model_fn_wrapper(*args, **kwargs): + comp_attn_state = kwargs.pop("comp_attn_state", None) + height = kwargs.get("height") + width = kwargs.get("width") + num_frames = kwargs.get("num_frames") + if num_frames is not None: + pipe._comp_attn_num_frames = num_frames + if comp_attn_state is None: + return orig_model_fn(*args, **kwargs) + latents = kwargs.get("latents") + timestep = kwargs.get("timestep") + context = kwargs.get("context") + clip_feature = kwargs.get("clip_feature") + reference_latents = kwargs.get("reference_latents") + if context is not None and 
timestep is not None: + context = apply_sci(context, comp_attn_state, timestep) + kwargs["context"] = context + if comp_attn_state.get("enable_lam", False) and latents is not None and height is not None and width is not None: + f, h, w = _get_grid_from_latents(latents, pipe.dit.patch_size) + base_f = f + q_len = f * h * w + if reference_latents is not None: + q_len = (f + 1) * h * w + layout_mask = comp_attn_state.get("layout_mask") + layout_shape = comp_attn_state.get("layout_shape") + if layout_mask is None or layout_shape != (latents.shape[0], q_len): + layout_mask = build_layout_mask_from_bboxes( + comp_attn_state.get("layout_bboxes"), + (base_f, h, w), + (int(height), int(width)), + device=latents.device, + dtype=latents.dtype, + ) + if reference_latents is not None: + pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype) + layout_mask = torch.cat([pad, layout_mask], dim=-1) + if layout_mask.shape[0] != latents.shape[0]: + layout_mask = layout_mask.repeat(latents.shape[0], 1, 1) + comp_attn_state["layout_mask"] = layout_mask + comp_attn_state["layout_shape"] = (latents.shape[0], q_len) + subject_mask = comp_attn_state.get("subject_token_mask") + if subject_mask is not None and clip_feature is not None and pipe.dit.require_clip_embedding: + pad_len = clip_feature.shape[1] + pad = torch.zeros((subject_mask.shape[0], pad_len), dtype=torch.bool) + comp_attn_state["subject_token_mask_lam"] = torch.cat([pad, subject_mask.cpu()], dim=1) + if ( + latents is not None + and latents.shape[0] == 2 + and not comp_attn_state.get("apply_to_negative", False) + and "apply_mask" not in comp_attn_state + ): + comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype) + pipe._comp_attn_runtime_state = comp_attn_state + try: + return orig_model_fn(*args, **kwargs) + finally: + pipe._comp_attn_runtime_state = None + + pipe.model_fn = model_fn_wrapper + pipe._comp_attn_model_fn_patched = True + + +def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline: + if getattr(pipe, "_comp_attn_attached", False): + return pipe + prompt_idx = None + cfg_idx = None + for idx, unit in enumerate(pipe.units): + if prompt_idx is None and isinstance(unit, WanVideoUnit_PromptEmbedder): + prompt_idx = idx + if cfg_idx is None and isinstance(unit, WanVideoUnit_CfgMerger): + cfg_idx = idx + if prompt_idx is not None: + pipe.units.insert(prompt_idx + 1, CompAttnUnit()) + else: + pipe.units.append(CompAttnUnit()) + if cfg_idx is not None: + pipe.units.insert(cfg_idx + 1, CompAttnMergeUnit()) + else: + pipe.units.append(CompAttnMergeUnit()) + _patch_cross_attention(pipe) + _wrap_model_fn(pipe) + pipe._comp_attn_attached = True + return pipe + + +class CompAttnPipelineWrapper: + def __init__(self, pipe: WanVideoPipeline): + self.pipe = attach_comp_attn(pipe) + + def __getattr__(self, name): + return getattr(self.pipe, name) + + def __call__(self, prompt: str, negative_prompt: str = "", comp_attn: Optional[CompAttnConfig] = None, **kwargs): + num_frames = kwargs.get("num_frames") + if num_frames is not None: + self.pipe._comp_attn_num_frames = num_frames + self.pipe._comp_attn_config = comp_attn + self.pipe._comp_attn_last_prompt = prompt + self.pipe._comp_attn_last_negative_prompt = negative_prompt + return self.pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs) + + +def build_comp_attn_pipeline(*args, **kwargs) -> CompAttnPipelineWrapper: + pipe = WanVideoPipeline.from_pretrained(*args, **kwargs) + return 
CompAttnPipelineWrapper(pipe) diff --git a/docs/comp_attn_design.md b/docs/comp_attn_design.md new file mode 100644 index 0000000000000000000000000000000000000000..9b24241009f71680efdcef688bdada4bf7d7ef95 --- /dev/null +++ b/docs/comp_attn_design.md @@ -0,0 +1,317 @@ +# Comp-Attn 设计文档 + +> 基于论文 "Comp-Attn: Present-and-Align Attention for Compositional Video Generation" 的实现总结 + +## 概述 + +Comp-Attn 是一种 **composition-aware cross-attention** 变体,采用 **"Present-and-Align"** 范式,用于解决多主体视频生成中的两个核心问题: + +| 挑战 | 描述 | 解决方案 | +|-----|------|---------| +| **Subject Presence** | 并非所有主体都能在视频中呈现 | SCI(条件层面) | +| **Inter-subject Relations** | 主体间的交互和空间关系错位 | LAM(注意力层面) | + +## 核心组件 + +### 1. Subject-aware Condition Interpolation (SCI) + +SCI 在 **条件编码阶段** 增强每个主体的语义表示,确保所有主体都能被"召回"。 + +#### 工作原理 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 输入: prompt 中的主体 tokens + 独立编码的 anchor embeddings │ +├─────────────────────────────────────────────────────────────┤ +│ Step 1: 计算语义显著性 (Saliency) │ +│ - 比较 prompt 中主体 token 与 anchor 的余弦相似度 │ +│ - 使用 softmax(τ=0.2) 归一化得到显著性分数 │ +│ - 低显著性 = 主体在 prompt 上下文中被"淹没" │ +├─────────────────────────────────────────────────────────────┤ +│ Step 2: 计算语义差异 (Delta) │ +│ - delta_i = anchor_i * N - Σ(anchors) │ +│ - 表示每个主体相对于其他主体的独特语义 │ +├─────────────────────────────────────────────────────────────┤ +│ Step 3: 自适应插值 │ +│ - scale = ω * (1 - saliency) │ +│ - context' = context + delta * scale │ +│ - 显著性越低,增强越多;早期时间步增强越强 │ +└─────────────────────────────────────────────────────────────┘ +``` + +#### 关键代码 + +```python +# 显著性计算 +def compute_saliency(prompt_vecs, anchor_vecs, tau=0.2): + cosine = cosine_similarity(prompt_vecs, anchor_vecs) + scores = exp(cosine / tau) + return scores.diagonal() / scores.sum(dim=1) + +# 应用 SCI +omega = 1.0 - (timestep / 1000.0) # 时间步调度 +scale = omega * (1.0 - saliency) +context = context + delta * scale +``` + +### 2. Layout-forcing Attention Modulation (LAM) + +LAM 在 **注意力计算阶段** 动态调制注意力分布,使其与预定义的空间布局对齐。 + +#### 工作原理 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 输入: Q (视频 tokens), K/V (文本 tokens), Layout (bbox) │ +├─────────────────────────────────────────────────────────────┤ +│ Step 1: 计算原始 attention scores │ +│ attn_scores = Q @ K^T / sqrt(d) │ +├─────────────────────────────────────────────────────────────┤ +│ Step 2: 构建调制函数 │ +│ g_plus = max(scores) - scores (增强调制) │ +│ g_minus = min(scores) - scores (抑制调制) │ +├─────────────────────────────────────────────────────────────┤ +│ Step 3: 计算 IOU 引导的强度 │ +│ - adapt_mask: 当前注意力分布 > mean 的区域 │ +│ - layout_mask: 目标 bbox 区域 │ +│ - iou = intersection / union │ +│ - strength = 1 - iou (IOU 越低,调制越强) │ +├─────────────────────────────────────────────────────────────┤ +│ Step 4: 应用动态调制 │ +│ - 在 bbox 内增强对应主体 token (g_plus) │ +│ - 抑制其他主体 token (g_minus) │ +│ - bias = g_k * strength * layout_mask │ +│ - final_scores = attn_scores + bias │ +└─────────────────────────────────────────────────────────────┘ +``` + +#### LAM vs 传统 Layout Control + +| 方法 | 机制 | 缺点 | +|-----|------|-----| +| **硬掩码 (Hard Mask)** | 强制注意力只在 bbox 内 | 无法适应多样的物体形状 | +| **LAM (Ours)** | IOU 引导的动态软调制 | ✅ 灵活适应不同形状 | + +### 3. 
关键帧插值 + +遵循论文附录 F 的设计,支持用 4 个关键帧描述运动轨迹,然后线性插值到所有帧: + +```python +# 输入: 4 个关键帧的 bbox +keyframe_bboxes = [frame_0, frame_1, frame_2, frame_3] + +# 输出: 81 帧的 bbox(线性插值) +all_bboxes = F.interpolate(keyframe_bboxes, size=81, mode='linear') +``` + +## 架构设计 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ WanVideoCompAttnPipeline │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────┐ ┌──────────────────┐ │ +│ │ CompAttnUnit │───▶│ CompAttnMerge │ │ +│ │ (SCI 预处理) │ │ (CFG 合并) │ │ +│ └──────────────────┘ └──────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ model_fn_wrapper (动态注入) │ │ +│ │ - apply_sci(): 修改 context │ │ +│ │ - build_layout_mask(): 构建空间掩码 │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ patched cross_attention │ │ +│ │ - lam_attention(): IOU 引导的注意力调制 │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Prompt-BBox 绑定机制 + +### 使用变量拼接(推荐) + +使用 Python 变量定义主体,通过 f-string 拼接 prompt,代码层面天然表达绑定关系: + +```python +# 定义主体变量 +subject0 = "red car" +subject1 = "blue bicycle" + +# 使用变量拼接 prompt +prompt = f"A {subject0} drives left, a {subject1} rides right" + +# subjects 和 bboxes 顺序一一对应 +subjects = [subject0, subject1] +bboxes = [bbox0, bbox1] # bbox0 对应 subject0, bbox1 对应 subject1 +``` + +### 绑定原理 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ 1. 用户定义 subjects 和 bboxes(顺序一一对应) │ +│ │ +│ subject0 = "red car" │ +│ subject1 = "blue bicycle" │ +│ subjects = [subject0, subject1] │ +│ bboxes = [bbox0, bbox1] │ +├─────────────────────────────────────────────────────────────────────────┤ +│ 2. Tokenize prompt,搜索每个 subject 的 token 位置 │ +│ │ +│ prompt = f"A {subject0} drives while a {subject1} rides..." │ +│ ↑↑↑↑↑↑↑ ↑↑↑↑↑↑↑↑↑↑↑↑ │ +│ token indices token indices │ +├─────────────────────────────────────────────────────────────────────────┤ +│ 3. 建立 subject_token_mask (关联 token 位置与 bbox) │ +│ │ +│ subject_token_mask[0, ...] = True # subject0 tokens │ +│ subject_token_mask[1, ...] = True # subject1 tokens │ +├─────────────────────────────────────────────────────────────────────────┤ +│ 4. 在推理时应用: │ +│ - SCI: 根据 mask 增强对应 token 的语义 │ +│ - LAM: 在 bbox 区域内增强对应 token 的注意力 │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +### 重要约束 + +1. **subjects 必须在 prompt 中出现**:使用变量拼接确保这一点 +2. **顺序一一对应**:`bboxes[i]` 对应 `subjects[i]` +3. **精确匹配**:subject 字符串需要能被分词器识别为完整的 token 序列 + +## 使用方法 + +### 基本用法 + +```python +from diffsynth.pipelines.wan_video_comp_attn import WanVideoCompAttnPipeline +from diffsynth.models.comp_attn_model import CompAttnConfig + +# 创建 pipeline +pipe = WanVideoCompAttnPipeline.from_pretrained(...) + +# 1. 定义主体变量 +subject0 = "red car" +subject1 = "blue bicycle" + +# 2. 定义运动轨迹 (4 个关键帧) +bbox0 = [(100, 250, 220, 380), (200, 250, 320, 380), (300, 250, 420, 380), (400, 250, 520, 380)] +bbox1 = [(350, 260, 410, 400), (450, 260, 510, 400), (550, 260, 610, 400), (650, 260, 710, 400)] + +# 3. 使用变量拼接 prompt +prompt = f"A {subject0} drives from left to center while a {subject1} rides to the right" + +# 4. 配置 Comp-Attn(顺序一一对应) +comp_attn = CompAttnConfig( + subjects=[subject0, subject1], # 变量列表 + bboxes=[bbox0, bbox1], # 对应的 bbox 列表 + enable_sci=True, + enable_lam=True, + interpolate=True, +) + +# 5. 
生成视频 +video = pipe(prompt=prompt, comp_attn=comp_attn) +``` + +### 运动轨迹辅助函数 + +```python +def create_moving_bbox(start_x, end_x, y_center, box_width, box_height, num_keyframes=4): + """创建从 start_x 移动到 end_x 的关键帧 bbox 序列""" + keyframes = [] + for i in range(num_keyframes): + progress = i / (num_keyframes - 1) + center_x = start_x + (end_x - start_x) * progress + left = center_x - box_width / 2 + right = center_x + box_width / 2 + top = y_center - box_height / 2 + bottom = y_center + box_height / 2 + keyframes.append((left, top, right, bottom)) + return keyframes + +# 使用示例 +car_trajectory = create_moving_bbox( + start_x=100, end_x=500, + y_center=300, + box_width=120, box_height=80, +) +``` + +### 配置参数 + +| 参数 | 类型 | 默认值 | 说明 | +|-----|------|-------|------| +| `subjects` | List[str] | - | 主体名称列表(必须出现在 prompt 中) | +| `bboxes` | List | None | 每个主体每个关键帧的 bbox | +| `enable_sci` | bool | True | 是否启用 SCI | +| `enable_lam` | bool | True | 是否启用 LAM | +| `temperature` | float | 0.2 | 显著性计算的温度参数 τ | +| `apply_to_negative` | bool | False | 是否对负样本 prompt 应用 | +| `interpolate` | bool | False | 是否对关键帧 bbox 进行插值 | + +### BBox 格式 + +```python +# 格式: (left, top, right, bottom) 像素坐标 +# 支持多种输入形式: + +# 1. 静态布局(所有帧相同) +bbox = (100, 200, 300, 400) + +# 2. 关键帧布局(4帧,会被插值) +bboxes = [ + (100, 200, 300, 400), # 关键帧 0 + (150, 200, 350, 400), # 关键帧 1 + (200, 200, 400, 400), # 关键帧 2 + (250, 200, 450, 400), # 关键帧 3 +] +``` + +## 与论文的对应关系 + +| 论文章节 | 实现位置 | 状态 | +|---------|---------|-----| +| Sec 3.2 SCI | `compute_saliency()`, `compute_delta()`, `apply_sci()` | ✅ | +| Sec 3.3 LAM | `lam_attention()`, `build_layout_mask_from_bboxes()` | ✅ | +| Appendix D 关键帧插值 | `interpolate_bboxes()` | ✅ | +| Training-free 集成 | `patch_cross_attention()`, `wrap_model_fn()` | ✅ | + +## 性能特点 + +根据论文数据: + +- **T2V-CompBench** 性能提升: +15.7% (Wan2.1-14B), +11.7% (Wan2.2-A14B) +- **推理时间增加**: 仅 ~5% +- **兼容性**: Wan, CogVideoX, VideoCrafter2, FLUX + +## 注意事项 + +1. **主体名称必须精确匹配**: `subjects` 中的字符串必须能在 prompt 中找到对应的 token +2. **BBox 使用像素坐标**: 不是归一化坐标 +3. **关键帧数量**: 推荐使用 4 个关键帧描述运动轨迹 +4. **温度参数**: τ 过小会导致显著性估计不稳定,过大会削弱增强效果 +5. **State tokens**: per-frame state control adds extra context tokens and attention bias. Keep the number of states small to reduce overhead. + +## Per-frame State Control + +You can inject per-frame instance states (e.g., "running" -> "idle") with: + +- `state_texts`: list of state names per subject +- `state_weights`: per-frame weights `(M, F, S)` or `(B, M, F, S)` +- `state_scale`: bias strength for state tokens +- `state_template`: default `"{subject} is {state}"` + +The implementation appends state tokens to the context and applies a per-frame attention bias +based on the current token time index. + +## 参考 + +- 论文: "Comp-Attn: Present-and-Align Attention for Compositional Video Generation" +- 作者: Hongyu Zhang, Yufan Deng, et al. 
(Peking University, Tsinghua University) diff --git a/docs/en/API_Reference/core/attention.md b/docs/en/API_Reference/core/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..9ec3123065b7f691dd1cc25d31878cfcaae20991 --- /dev/null +++ b/docs/en/API_Reference/core/attention.md @@ -0,0 +1,79 @@ +# `diffsynth.core.attention`: Attention Mechanism Implementation + +`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). + +## Attention Mechanism + +The attention mechanism is a model structure proposed in the paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). In the original paper, the attention mechanism is implemented according to the following formula: + +$$ +\text{Attention}(Q, K, V) = \text{Softmax}\left( + \frac{QK^T}{\sqrt{d_k}} +\right) +V. +$$ + +In `PyTorch`, it can be implemented with the following code: +```python +import torch + +def attention(query, key, value): + scale_factor = 1 / query.size(-1)**0.5 + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + +query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +output_1 = attention(query, key, value) +``` + +The dimensions of `query`, `key`, and `value` are $(b, n, s, d)$: +* $b$: Batch size +* $n$: Number of attention heads +* $s$: Sequence length +* $d$: Dimension of each attention head + +This computation does not include any trainable parameters. Modern transformer architectures will pass through Linear layers before and after this computation, but the "attention mechanism" discussed in this article refers only to the computation in the above code, not including these calculations. + +## More Efficient Implementations + +Note that the dimension of the Attention Score in the attention mechanism ( $\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$ in the formula, `attn_weight` in the code) is $(b, n, s, s)$, where the sequence length $s$ is typically very large, causing the time and space complexity of computation to reach quadratic level. Taking image generation models as an example, when the width and height of the image increase to 2 times, the sequence length increases to 4 times, and the computational load and memory requirements increase to 16 times. 
To avoid high computational costs, more efficient attention mechanism implementations are needed, including: +* Flash Attention 3: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2407.08608) +* Flash Attention 2: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2307.08691) +* Sage Attention: [GitHub](https://github.com/thu-ml/SageAttention), [Paper](https://arxiv.org/abs/2505.11594) +* xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) +* PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) + +To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation). + +```python +from diffsynth.core.attention import attention_forward +import torch + +def attention(query, key, value): + scale_factor = 1 / query.size(-1)**0.5 + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + +query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +output_1 = attention(query, key, value) +output_2 = attention_forward(query, key, value) +print((output_1 - output_2).abs().mean()) +``` + +Please note that acceleration will introduce errors, but in most cases, the error is negligible. + +## Developer Guide + +When integrating new models into `DiffSynth-Studio`, developers can decide whether to call `attention_forward` in `diffsynth.core.attention`, but we expect models to prioritize calling this module as much as possible, so that new attention mechanism implementations can take effect directly on these models. + +## Best Practices + +**In most cases, we recommend directly using the native `PyTorch` implementation without installing any additional packages.** Although other attention mechanism implementations can accelerate, the acceleration effect is relatively limited, and in a few cases, compatibility and precision issues may arise. + +In addition, efficient attention mechanism implementations will gradually be integrated into `PyTorch`. The `scaled_dot_product_attention` in `PyTorch` version 2.9.0 has already integrated Flash Attention 2. We still provide this interface in `DiffSynth-Studio` to allow some aggressive acceleration schemes to quickly move toward application, even though they still need time to be verified for stability. 
\ No newline at end of file diff --git a/docs/en/API_Reference/core/data.md b/docs/en/API_Reference/core/data.md new file mode 100644 index 0000000000000000000000000000000000000000..0a6f11dd3e7f36c94f3c8b6c3d560143209a8db9 --- /dev/null +++ b/docs/en/API_Reference/core/data.md @@ -0,0 +1,151 @@ +# `diffsynth.core.data`: Data Processing Operators and Universal Dataset + +## Data Processing Operators + +### Available Data Processing Operators + +`diffsynth.core.data` provides a series of data processing operators for data processing, including: + +* Data format conversion operators + * `ToInt`: Convert to int format + * `ToFloat`: Convert to float format + * `ToStr`: Convert to str format + * `ToList`: Convert to list format, wrapping this data in a list + * `ToAbsolutePath`: Convert relative paths to absolute paths +* File loading operators + * `LoadImage`: Read image files + * `LoadVideo`: Read video files + * `LoadAudio`: Read audio files + * `LoadGIF`: Read GIF files + * `LoadTorchPickle`: Read binary files saved by [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) [This operator may cause code injection attacks in binary files, please use with caution!] +* Media file processing operators + * `ImageCropAndResize`: Crop and resize images +* Meta operators + * `SequencialProcess`: Route each data in the sequence to an operator + * `RouteByExtensionName`: Route to specific operators by file extension + * `RouteByType`: Route to specific operators by data type + +### Operator Usage + +Data operators are connected with the `>>` symbol to form data processing pipelines, for example: + +```python +from diffsynth.core.data.operators import * + +data = "image.jpg" +data_pipeline = ToAbsolutePath(base_path="/data") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512) +data = data_pipeline(data) +``` + +After passing through each operator, the data is processed in sequence: + +* `ToAbsolutePath(base_path="/data")`: `"/data/image.jpg"` +* `LoadImage()`: `` +* `ImageCropAndResize(max_pixels=512*512)`: `` + +We can compose functionally complete data pipelines, for example, the default video data operator for the universal dataset is: + +```python +RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), +]) +``` + +It includes the following logic: + +* If the data is of type `str` + * If it's a `"jpg", "jpeg", "png", "webp"` type file + * Load this image + * Crop and scale to a specific resolution + * Pack into a list, treating it as a single-frame video + * If it's a `"gif"` type file + * Load the GIF file content + * Crop and scale each frame to a specific resolution + * If it's a `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"` type file + * Load the video file content + * Crop and scale each frame to a specific resolution +* If the data is not of type `str`, an error is reported + +## Universal Dataset + 
+`diffsynth.core.data` provides a unified dataset implementation. The dataset requires the following parameters: + +* `base_path`: Root directory. If the dataset contains relative paths to image files, this field needs to be filled in to load the files pointed to by these paths +* `metadata_path`: Metadata directory, records the file paths of all metadata, supports `csv`, `json`, `jsonl` formats +* `repeat`: Data repetition count, defaults to 1, this parameter affects the number of training steps in an epoch +* `data_file_keys`: Data field names that need to be loaded, for example `(image, edit_image)` +* `main_data_operator`: Main loading operator, needs to assemble the data processing pipeline through data processing operators +* `special_operator_map`: Special operator mapping, operator mappings built for fields that require special processing + +### Metadata + +The dataset's `metadata_path` points to a metadata file, supporting `csv`, `json`, `jsonl` formats. The following provides examples: + +* `csv` format: High readability, does not support list data, small memory footprint + +```csv +image,prompt +image_1.jpg,"a dog" +image_2.jpg,"a cat" +``` + +* `json` format: High readability, supports list data, large memory footprint + +```json +[ + { + "image": "image_1.jpg", + "prompt": "a dog" + }, + { + "image": "image_2.jpg", + "prompt": "a cat" + } +] +``` + +* `jsonl` format: Low readability, supports list data, small memory footprint + +```json +{"image": "image_1.jpg", "prompt": "a dog"} +{"image": "image_2.jpg", "prompt": "a cat"} +``` + +How to choose the best metadata format? + +* If the data volume is large, reaching tens of millions, since `json` file parsing requires additional memory, it's not available. Please use `csv` or `jsonl` format +* If the dataset contains list data, such as edit models that require multiple images as input, since `csv` format cannot store list format data, it's not available. Please use `json` or `jsonl` format + +### Data Loading Logic + +When no additional settings are made, the dataset defaults to outputting data from the metadata set. Image and video file paths will be output in string format. To load these files, you need to set `data_file_keys`, `main_data_operator`, and `special_operator_map`. + +In the data processing flow, processing is done according to the following logic: +* If the field is in `special_operator_map`, call the corresponding operator in `special_operator_map` for processing +* If the field is not in `special_operator_map` + * If the field is in `data_file_keys`, call the `main_data_operator` operator for processing + * If the field is not in `data_file_keys`, no processing is done + +`special_operator_map` can be used to implement special data processing. For example, in the model [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B), the input character face video `animate_face_video` is processed at a fixed resolution, inconsistent with the output video. Therefore, this field is processed by a dedicated operator: + +```python +special_operator_map={ + "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)), +} +``` + +### Other Notes + +When the data volume is too small, you can appropriately increase `repeat` to extend the training time of a single epoch, avoiding frequent model saving that generates considerable overhead. 
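Putting the operators and dataset parameters documented above together, a minimal configuration might look like the sketch below. The operators and parameter names come from this document; the dataset class name (`UnifiedDataset`) and its import path are assumptions, so check `diffsynth.core.data` for the actual symbol.

```python
from diffsynth.core.data.operators import (
    ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, RouteByExtensionName, ToList,
)

base_path = "/data/train"

# Resolve relative paths, then load images and crop/resize them to a bounded pixel count.
main_data_operator = RouteByType(operator_map=[
    (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
        (("jpg", "jpeg", "png", "webp"),
         LoadImage() >> ImageCropAndResize(max_pixels=1024 * 1024) >> ToList()),
    ])),
])

# Hypothetical class name for the unified dataset described in this document.
from diffsynth.core.data import UnifiedDataset

dataset = UnifiedDataset(
    base_path=base_path,
    metadata_path=f"{base_path}/metadata.csv",  # columns: image, prompt
    repeat=4,                                    # small dataset -> lengthen each epoch
    data_file_keys=("image",),
    main_data_operator=main_data_operator,
)
```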
+ +When data volume * `repeat` exceeds $10^9$, we observe that the dataset speed becomes significantly slower. This seems to be a `PyTorch` bug, and we are not sure if newer versions of `PyTorch` have fixed this issue. \ No newline at end of file diff --git a/docs/en/API_Reference/core/gradient.md b/docs/en/API_Reference/core/gradient.md new file mode 100644 index 0000000000000000000000000000000000000000..eeca81cac27028d51fc49e8088fa8faf30c23faf --- /dev/null +++ b/docs/en/API_Reference/core/gradient.md @@ -0,0 +1,69 @@ +# `diffsynth.core.gradient`: Gradient Checkpointing and Offload + +`diffsynth.core.gradient` provides encapsulated gradient checkpointing and its Offload version for model training. + +## Gradient Checkpointing + +Gradient checkpointing is a technique used to reduce memory usage during training. We provide an example to help you understand this technique. Here is a simple model structure: + +```python +import torch + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = model(x) +``` + +In this model structure, the input parameter $x$ passes through the Sigmoid activation function to obtain the output value $y=\frac{1}{1+e^{-x}}$. + +During the training process, assuming our loss function value is $\mathcal L$, when backpropagating gradients, we obtain $\frac{\partial \mathcal L}{\partial y}$. At this point, we need to calculate $\frac{\partial \mathcal L}{\partial x}$. It's not difficult to find that $\frac{\partial y}{\partial x}=y(1-y)$, and thus $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$. If we save the value of $y$ during the model's forward propagation and directly compute $y(1-y)$ during gradient backpropagation, this will avoid complex exp computations, speeding up the calculation. However, this requires additional memory to store the intermediate variable $y$. + +When gradient checkpointing is not enabled, the training framework will default to storing all intermediate variables that assist gradient computation, thereby achieving optimal computational speed. When gradient checkpointing is enabled, intermediate variables are not stored, but the input parameter $x$ is still stored, reducing memory usage. During gradient backpropagation, these variables need to be recomputed, slowing down the calculation. + +## Enabling Gradient Checkpointing and Its Offload + +`gradient_checkpoint_forward` in `diffsynth.core.gradient` implements gradient checkpointing and its Offload. Refer to the following code for calling: + +```python +import torch +from diffsynth.core.gradient import gradient_checkpoint_forward + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = gradient_checkpoint_forward( + model, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + x=x, +) +``` + +* When `use_gradient_checkpointing=False` and `use_gradient_checkpointing_offload=False`, the computation process is exactly the same as the original computation, not affecting the model's inference and training. You can directly integrate it into your code. 
+* When `use_gradient_checkpointing=True` and `use_gradient_checkpointing_offload=False`, gradient checkpointing is enabled. +* When `use_gradient_checkpointing_offload=True`, gradient checkpointing is enabled, and all gradient checkpoint input parameters are stored in memory, further reducing memory usage and slowing down computation. + +## Best Practices + +> Q: Where should gradient checkpointing be enabled? +> +> A: When enabling gradient checkpointing for the entire model, computational efficiency and memory usage are not optimal. We need to set fine-grained gradient checkpoints, but we don't want to add too much complicated code to the framework. Therefore, we recommend implementing it in the `model_fn` of `Pipeline`, for example, `model_fn_qwen_image` in `diffsynth/pipelines/qwen_image.py`, enabling gradient checkpointing at the Block level without modifying any code in the model structure. + +> Q: When should gradient checkpointing be enabled? +> +> A: As model parameters become increasingly large, gradient checkpointing has become a necessary training technique. Gradient checkpointing usually needs to be enabled. Gradient checkpointing Offload should only be enabled in models where activation values occupy excessive memory (such as video generation models). \ No newline at end of file diff --git a/docs/en/API_Reference/core/loader.md b/docs/en/API_Reference/core/loader.md new file mode 100644 index 0000000000000000000000000000000000000000..1dccf5f495189484d0dff7f213e740cee492a969 --- /dev/null +++ b/docs/en/API_Reference/core/loader.md @@ -0,0 +1,141 @@ +# `diffsynth.core.loader`: Model Download and Loading + +This document introduces the model download and loading functionalities in `diffsynth.core.loader`. + +## ModelConfig + +`ModelConfig` in `diffsynth.core.loader` is used to annotate model download sources, local paths, VRAM management configurations, and other information. + +### Downloading and Loading Models from Remote Sources + +Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). + +By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig( + model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", + origin_file_pattern="model.safetensors", +) +# Download models +config.download_if_necessary() +print(config.path) +``` + +After calling `download_if_necessary`, the model will be automatically downloaded, and the path will be returned to `config.path`. 
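For offline machines or shared caches, the two environment variables mentioned above can be combined with `ModelConfig`. Setting them from Python before importing `diffsynth` (as in the sketch below) is an assumption about when they are read; exporting them in the shell works regardless.

```python
import os

# Keep downloads in a persistent location and skip remote lookups once files exist locally.
os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "/mnt/models"
os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True"

from diffsynth.core import ModelConfig

config = ModelConfig(
    model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny",
    origin_file_pattern="model.safetensors",
)
config.download_if_necessary()  # resolves to the local copy without querying the remote
print(config.path)
```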
+ +### Loading Models from Local Paths + +If loading models from local paths, you need to fill in `path`: + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig(path="models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors") +``` + +If the model contains multiple shard files, input them in list form: + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig(path=[ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +]) +``` + +### VRAM Management Configuration + +`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods) for details. + +## Model File Loading + +`diffsynth.core.loader` provides a unified `load_state_dict` for loading state dicts from model files. + +Loading a single model file: + +```python +from diffsynth.core import load_state_dict + +state_dict = load_state_dict("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors") +``` + +Loading multiple model files (merged into one state dict): + +```python +from diffsynth.core import load_state_dict + +state_dict = load_state_dict([ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +]) +``` + +## Model Hash + +Model hash is used to determine the model type. The hash value can be obtained through `hash_model_file`: + +```python +from diffsynth.core import hash_model_file + +print(hash_model_file("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")) +``` + +The hash value of multiple model files can also be calculated, which is equivalent to calculating the model hash value after merging the state dict: + +```python +from diffsynth.core import hash_model_file + +print(hash_model_file([ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +])) +``` + +The model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. When calculating the model hash value of `.safetensors` format files, `hash_model_file` is almost instantly completed without reading the model parameters. However, when calculating the model hash value of `.bin`, `.pth`, `.ckpt`, and other binary files, all model parameters need to be read, so **we do not recommend developers to continue using these formats of files.** + +By [writing model Config](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it. 
+ +## Model Loading + +`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](/docs/en/API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](/docs/en/API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode. + +Here is a usage example with Disk Offload enabled: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] + +model = load_model( + QwenImageDiT, + model_path, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config={ + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +``` \ No newline at end of file diff --git a/docs/en/API_Reference/core/vram.md b/docs/en/API_Reference/core/vram.md new file mode 100644 index 0000000000000000000000000000000000000000..79e51fc2ff7ecf2b946ee2ad787e2e68cd204092 --- /dev/null +++ b/docs/en/API_Reference/core/vram.md @@ -0,0 +1,66 @@ +# `diffsynth.core.vram`: VRAM Management + +This document introduces the underlying VRAM management functionalities in `diffsynth.core.vram`. If you wish to use these functionalities in other codebases, you can refer to this document. + +## Skipping Model Parameter Initialization + +When loading models in `PyTorch`, model parameters default to occupying VRAM or memory and initializing parameters, but these parameters will be overwritten when loading pretrained weights, leading to redundant computations. `PyTorch` does not provide an interface to skip these redundant computations. We provide `skip_model_initialization` in `diffsynth.core.vram` to skip model parameter initialization. 
+ +Default model loading approach: + +```python +from diffsynth.core import load_state_dict +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet + +model = QwenImageBlockWiseControlNet() # Slow +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") +model.load_state_dict(state_dict, assign=True) +``` + +Model loading approach that skips parameter initialization: + +```python +from diffsynth.core import load_state_dict, skip_model_initialization +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet + +with skip_model_initialization(): + model = QwenImageBlockWiseControlNet() # Fast +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") +model.load_state_dict(state_dict, assign=True) +``` + +In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](/docs/en/Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach. + +## State Dict Disk Mapping + +For pretrained weight files of a model, if we only need to read a set of parameters rather than all parameters, State Dict Disk Mapping can accelerate this process. We provide `DiskMap` in `diffsynth.core.vram` for on-demand loading of model parameters. + +Default weight loading approach: + +```python +from diffsynth.core import load_state_dict + +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") # Slow +print(state_dict["img_in.weight"]) +``` + +Using `DiskMap` to load only specific parameters: + +```python +from diffsynth.core import DiskMap + +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = DiskMap(path, device="cpu") # Fast +print(state_dict["img_in.weight"]) +``` + +`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](/docs/en/Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload. + +`DiskMap` is a functionality implemented using the characteristics of `.safetensors` files. Therefore, when using `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are fully loaded, which causes Disk Offload to not support these formats of files. **We do not recommend developers to continue using these formats of files.** + +## Replacable Modules for VRAM Management + +When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](/docs/en/Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes). \ No newline at end of file diff --git a/docs/en/Developer_Guide/Building_a_Pipeline.md b/docs/en/Developer_Guide/Building_a_Pipeline.md new file mode 100644 index 0000000000000000000000000000000000000000..7d5e7856ef24bdf1078fc5c210b79a7a7c3c8d0a --- /dev/null +++ b/docs/en/Developer_Guide/Building_a_Pipeline.md @@ -0,0 +1,250 @@ +# Building a Pipeline + +After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. 
Developers can also refer to existing `Pipeline` implementations for construction. + +The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components: + +* `__init__` +* `from_pretrained` +* `__call__` +* `units` +* `model_fn` + +## `__init__` + +In `__init__`, the `Pipeline` is initialized. Here is a simple implementation: + +```python +import torch +from PIL import Image +from typing import Union +from tqdm import tqdm +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit +from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model + +class NewDiffSynthPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler() + self.text_encoder: XXX_Model = None + self.dit: YYY_Model = None + self.vae: ZZZ_Model = None + self.in_iteration_models = ("dit",) + self.units = [ + NewDiffSynthPipelineUnit_xxx(), + ... + ] + self.model_fn = model_fn_new +``` + +This includes the following parts: + +* `scheduler`: Scheduler, used to control the coefficients in the iterative formula during inference, controlling the noise content at each step. +* `text_encoder`, `dit`, `vae`: Models. Since [Latent Diffusion](https://arxiv.org/abs/2112.10752) was proposed, this three-stage model architecture has become the mainstream Diffusion model architecture. However, this is not immutable, and any number of models can be added to the `Pipeline`. +* `in_iteration_models`: Iteration models. This tuple marks which models will be called during iteration. +* `units`: Pre-processing units for model iteration. See [`units`](#units) for details. +* `model_fn`: The `forward` function of the denoising model during iteration. See [`model_fn`](#model_fn) for details. + +> Q: Model loading does not occur in `__init__`, why initialize each model as `None` here? +> +> A: By annotating the type of each model here, the code editor can provide code completion prompts based on each model, facilitating subsequent development. + +## `from_pretrained` + +`from_pretrained` is responsible for loading the required models to make the `Pipeline` callable. Here is a simple implementation: + +```python + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + vram_limit: float = None, + ): + # Initialize pipeline + pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder") + pipe.dit = model_pool.fetch_model("yyy_dit") + pipe.vae = model_pool.fetch_model("zzz_vae") + # If necessary, load tokenizers here. + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe +``` + +Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config). + +Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models. 
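From the caller's perspective, the `from_pretrained` above is then used with a list of `ModelConfig` entries. The sketch below is illustrative only: the import path and model ids are placeholders, and the `__call__` arguments refer to the signature defined in the next section.

```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.new_models import NewDiffSynthPipeline  # hypothetical module path

pipe = NewDiffSynthPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Placeholder ids/files; their hashes must map to xxx_text_encoder, yyy_dit, zzz_vae
        # in diffsynth/configs/model_configs.py.
        ModelConfig(model_id="my-org/new-model", origin_file_pattern="text_encoder.safetensors"),
        ModelConfig(model_id="my-org/new-model", origin_file_pattern="dit.safetensors"),
        ModelConfig(model_id="my-org/new-model", origin_file_pattern="vae.safetensors"),
    ],
)
image = pipe(prompt="a cat sitting on a windowsill", seed=0, num_inference_steps=30)
image.save("new_pipeline_demo.jpg")
```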
+ +## `__call__` + +`__call__` implements the entire generation process of the Pipeline. Below is a common generation process template. Developers can modify it based on their needs. + +```python + @torch.no_grad() + def __call__( + self, + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + input_image: Image.Image = None, + denoising_strength: float = 1.0, + height: int = 1328, + width: int = 1328, + seed: int = None, + rand_device: str = "cpu", + num_inference_steps: int = 30, + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps( + num_inference_steps, + denoising_strength=denoising_strength + ) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, + "denoising_strength": denoising_strength, + "height": height, + "width": width, + "seed": seed, + "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id) + if cfg_scale != 1.0: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"], device=self.device) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image +``` + +## `units` + +`units` contains all the preprocessing processes, such as: width/height checking, prompt encoding, initial noise generation, etc. In the entire model preprocessing process, data is abstracted into three mutually exclusive parts, stored in corresponding dictionaries: + +* `inputs_shared`: Shared inputs, parameters unrelated to [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598) (CFG for short). +* `inputs_posi`: Positive side inputs for Classifier-Free Guidance, containing content related to positive prompts. +* `inputs_nega`: Negative side inputs for Classifier-Free Guidance, containing content related to negative prompts. + +Pipeline Unit implementations include three types: direct mode, CFG separation mode, and takeover mode. 
+ +If some calculations are unrelated to CFG, direct mode can be used, for example, Qwen-Image's random noise initialization: + +```python +class QwenImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} +``` + +If some calculations are related to CFG and need to separately process positive and negative prompts, but the input parameters on both sides are the same, CFG separation mode can be used, for example, Qwen-image's prompt encoding: + +```python +class QwenImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: + pipe.load_models_to_device(self.onload_model_names) + # Do something + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} +``` + +If some calculations need global information, takeover mode is required, for example, Qwen-Image's entity partition control: + +```python +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + # Do something + return inputs_shared, inputs_posi, inputs_nega +``` + +The following are the parameter configurations required for Pipeline Unit: + +* `seperate_cfg`: Whether to enable CFG separation mode +* `take_over`: Whether to enable takeover mode +* `input_params`: Shared input parameters +* `output_params`: Output parameters +* `input_params_posi`: Positive side input parameters +* `input_params_nega`: Negative side input parameters +* `onload_model_names`: Names of model components to be called + +When designing `unit`, please try to follow these principles: + +* Default fallback: For optional function `unit` input parameters, the default is `None` rather than `False` or other values. Please provide fallback processing for this default value. +* Parameter triggering: Some Adapter models may not be loaded, such as ControlNet. The corresponding `unit` should control triggering based on whether the parameter input is `None` rather than whether the model is loaded. For example, when the user inputs `controlnet_image` but does not load the ControlNet model, the code should give an error rather than ignore these input parameters and continue execution. +* Simplicity first: Use direct mode as much as possible, only use takeover mode when the function cannot be implemented. +* VRAM efficiency: When calling models in `unit`, please use `pipe.load_models_to_device(self.onload_model_names)` to activate the corresponding models. Do not call other models outside `onload_model_names`. 
+After the `unit` finishes its computation, do not manually release VRAM with `pipe.load_models_to_device([])`.
+
+> Q: Some parameters, such as `output_params`, are not used during inference. Is it still necessary to configure them?
+>
+> A: These parameters do not affect the inference process, but they do affect some experimental features, so we recommend configuring them properly. For example, in "split training" we can complete the preprocessing offline, but model calculations that require gradient backpropagation cannot be split out. These parameters are used to build a computational graph and infer which calculations can be split.
+
+## `model_fn`
+
+`model_fn` is the unified `forward` interface during iteration. For models without an established open-source ecosystem, you can directly call the denoising model's `forward`, for example:
+
+```python
+def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):
+    return dit(latents, prompt_emb, timestep)
+```
+
+For models with rich open-source ecosystems, `model_fn` usually contains complex, intertwined cross-model inference. Taking `diffsynth/pipelines/qwen_image.py` as an example, the additional calculations implemented in this function include entity partition control, three types of ControlNet, Gradient Checkpointing, and more. Developers need to be especially careful when implementing this part to avoid conflicts between module features.
\ No newline at end of file
diff --git a/docs/en/Developer_Guide/Enabling_VRAM_management.md b/docs/en/Developer_Guide/Enabling_VRAM_management.md
new file mode 100644
index 0000000000000000000000000000000000000000..9bdd49f10ef450ee773a176070019adf43081891
--- /dev/null
+++ b/docs/en/Developer_Guide/Enabling_VRAM_management.md
@@ -0,0 +1,455 @@
+# Fine-Grained VRAM Management Scheme
+
+This document introduces how to write a reasonable fine-grained VRAM management scheme for a model, and how to use `DiffSynth-Studio`'s VRAM management features in external code bases. Before reading this document, please read [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
+
+## How Much VRAM Does a 20B Model Need?
+
+Taking Qwen-Image's DiT model as an example, this model has 20B parameters. The following code loads the model and performs inference, requiring about 40 GB of VRAM, so it clearly cannot run on consumer-grade GPUs with less VRAM.
+ +```python +from diffsynth.core import load_model +from diffsynth.models.qwen_image_dit import QwenImageDiT +from modelscope import snapshot_download +import torch + +snapshot_download( + model_id="Qwen/Qwen-Image", + local_dir="models/Qwen/Qwen-Image", + allow_file_pattern="transformer/*" +) +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda") +with torch.no_grad(): + output = model(**inputs) +``` + +## Writing Fine-Grained VRAM Management Scheme + +To write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure: + +``` +QwenImageDiT( + (pos_embed): QwenEmbedRope() + (time_text_embed): TimestepEmbeddings( + (time_proj): TemporalTimesteps() + (timestep_embedder): DiffusersCompatibleTimestepProj( + (linear_1): Linear(in_features=256, out_features=3072, bias=True) + (act): SiLU() + (linear_2): Linear(in_features=3072, out_features=3072, bias=True) + ) + ) + (txt_norm): RMSNorm() + (img_in): Linear(in_features=64, out_features=3072, bias=True) + (txt_in): Linear(in_features=3584, out_features=3072, bias=True) + (transformer_blocks): ModuleList( + (0-59): 60 x QwenImageTransformerBlock( + (img_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (attn): QwenDoubleStreamAttention( + (to_q): Linear(in_features=3072, out_features=3072, bias=True) + (to_k): Linear(in_features=3072, out_features=3072, bias=True) + (to_v): Linear(in_features=3072, out_features=3072, bias=True) + (norm_q): RMSNorm() + (norm_k): RMSNorm() + (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True) + (norm_added_q): RMSNorm() + (norm_added_k): RMSNorm() + (to_out): Sequential( + (0): Linear(in_features=3072, out_features=3072, bias=True) + ) + (to_add_out): Linear(in_features=3072, out_features=3072, bias=True) + ) + (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (img_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + (txt_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + ) + ) + (norm_out): AdaLayerNorm( + (linear): Linear(in_features=3072, out_features=6144, bias=True) + (norm): LayerNorm((3072,), 
eps=1e-06, elementwise_affine=False) + ) + (proj_out): Linear(in_features=3072, out_features=64, bias=True) +) +``` + +In VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`. + +`diffsynth.core.vram` provides two replacement modules for VRAM management: +* `AutoWrappedLinear`: Used to replace `Linear` layers +* `AutoWrappedModule`: Used to replace any other layer + +Write a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules: + +```python +module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, +} +``` + +In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods). + +Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu") +enable_vram_management( + model, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +with torch.no_grad(): + output = model(**inputs) +``` + +The above code only requires 2G VRAM to run the `forward` of a 20B model. + +## Disk Offload + +[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. 
+Usually, when the above code runs smoothly, Disk Offload can be enabled directly:
+
+```python
+from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
+from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
+import torch
+
+prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
+model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
+inputs = {
+    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
+    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
+    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
+    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
+    "height": 1024,
+    "width": 1024,
+}
+
+model = load_model(
+    QwenImageDiT,
+    model_path,
+    module_map={
+        torch.nn.Linear: AutoWrappedLinear,
+        RMSNorm: AutoWrappedModule,
+    },
+    vram_config={
+        "offload_dtype": "disk",
+        "offload_device": "disk",
+        "onload_dtype": "disk",
+        "onload_device": "disk",
+        "preparing_dtype": torch.bfloat16,
+        "preparing_device": "cuda",
+        "computation_dtype": torch.bfloat16,
+        "computation_device": "cuda",
+    },
+    vram_limit=0,
+)
+with torch.no_grad():
+    output = model(**inputs)
+```
+
+Disk Offload is a highly specialized VRAM management scheme. It only supports `.safetensors` files, not binary files such as `.bin`, `.pth`, and `.ckpt`, and it does not support a [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) that reshapes tensors.
+
+If a model runs normally without Disk Offload but fails when Disk Offload is enabled, please submit an issue to us on GitHub.
+
+## Writing Default Configuration
+
+To make the VRAM management function easier for users to use, we record each model's fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration for the above model is:
+
+```python
+"diffsynth.models.qwen_image_dit.QwenImageDiT": {
+    "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+}
+```
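+
+As a side note, each entry maps a module's import path to the import path of its replacement wrapper, mirroring the class-keyed `module_map` passed to `enable_vram_management`. The snippet below is only an illustrative sketch of how such string entries can be resolved into classes; it is not the framework's actual loader:
+
+```python
+import importlib
+
+def resolve_class(import_path: str):
+    # Split "package.module.ClassName" into its module and class parts
+    module_name, class_name = import_path.rsplit(".", 1)
+    return getattr(importlib.import_module(module_name), class_name)
+
+config = {
+    "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
+    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
+}
+# Resolve the string-keyed configuration into a class-keyed module_map
+module_map = {resolve_class(src): resolve_class(dst) for src, dst in config.items()}
+```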
\ No newline at end of file
diff --git a/docs/en/Developer_Guide/Integrating_Your_Model.md b/docs/en/Developer_Guide/Integrating_Your_Model.md
new file mode 100644
index 0000000000000000000000000000000000000000..ae5e6f2a9d584a95227ce37d43982e7f05f1be7f
--- /dev/null
+++ b/docs/en/Developer_Guide/Integrating_Your_Model.md
@@ -0,0 +1,186 @@
+# Integrating Model Architecture
+
+This document introduces how to integrate models into the `DiffSynth-Studio` framework for use by modules such as `Pipeline`.
+
+## Step 1: Integrate Model Architecture Code
+
+All model architecture implementations in `DiffSynth-Studio` are unified in `diffsynth/models`. Each `.py` code file implements a model architecture, and all models are loaded through `ModelPool` in `diffsynth/models/model_loader.py`. When integrating new model architectures, please create a new `.py` file under this path.
+
+```shell
+diffsynth/models/
+├── general_modules.py
+├── model_loader.py
+├── qwen_image_controlnet.py
+├── qwen_image_dit.py
+├── qwen_image_text_encoder.py
+├── qwen_image_vae.py
+└── ...
+``` + +In most cases, we recommend integrating models in native `PyTorch` code form, with the model architecture class directly inheriting from `torch.nn.Module`, for example: + +```python +import torch + +class NewDiffSynthModel(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + x = self.linear(x) + x = self.activation(x) + return x +``` + +If the model architecture implementation contains additional dependencies, we strongly recommend removing them, otherwise this will cause heavy package dependency issues. In our existing models, Qwen-Image's Blockwise ControlNet is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_controlnet.py`. + +If the model has been integrated by Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index), [`diffusers`](https://huggingface.co/docs/diffusers/main/index), etc.), we can integrate the model in a simpler way: + +
+Integrating Huggingface Library Style Model Architecture Code + +The loading method for these models in Huggingface Library is: + +```python +from transformers import XXX_Model + +model = XXX_Model.from_pretrained("path_to_your_model") +``` + +`DiffSynth-Studio` does not support loading models through `from_pretrained` because this conflicts with VRAM management and other functions. Please rewrite the model architecture in the following format: + +```python +import torch + +class DiffSynth_XXX_Model(torch.nn.Module): + def __init__(self): + super().__init__() + from transformers import XXX_Config, XXX_Model + config = XXX_Config(**{ + "architectures": ["XXX_Model"], + "other_configs": "Please copy and paste the other configs here.", + }) + self.model = XXX_Model(config) + + def forward(self, x): + outputs = self.model(x) + return outputs +``` + +Where `XXX_Config` is the Config class corresponding to the model. For example, the Config class for `Qwen2_5_VLModel` is `Qwen2_5_VLConfig`, which can be found by consulting its source code. The content inside Config can usually be found in the `config.json` file in the model library. `DiffSynth-Studio` will not read the `config.json` file, so the content needs to be copied and pasted into the code. + +In rare cases, version updates of `transformers` and `diffusers` may cause some models to be unable to import. Therefore, if possible, we still recommend using the model integration method in Step 1.1. + +In our existing models, Qwen-Image's Text Encoder is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_text_encoder.py`. + +
+ +## Step 2: Model File Format Conversion + +Due to the variety of model file formats provided by developers in the open-source community, we sometimes need to convert model file formats to form correctly formatted [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html). This is common in the following situations: + +* Model files built by different code libraries, for example [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) and [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers). +* Models modified during integration, for example, the Text Encoder of [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) adds a `model.` prefix in `diffsynth/models/qwen_image_text_encoder.py`. +* Model files containing multiple models, for example, the VACE Adapter and base DiT model of [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) are mixed and stored in the same set of model files. + +In our development philosophy, we hope to respect the wishes of model authors as much as possible. If we repackage the model files, for example [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI), although we can call the model more conveniently, traffic (model page views and downloads, etc.) will be directed elsewhere, and the original author of the model will also lose the power to delete the model. Therefore, we have added the `diffsynth/utils/state_dict_converters` module to the framework for file format conversion during model loading. + +This part of logic is very simple. Taking Qwen-Image's Text Encoder as an example, only 10 lines of code are needed: + +```python +def QwenImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for k in state_dict: + v = state_dict[k] + if k.startswith("visual."): + k = "model." + k + elif k.startswith("model."): + k = k.replace("model.", "model.language_model.") + state_dict_[k] = v + return state_dict_ +``` + +## Step 3: Writing Model Config + +Model Config is located in `diffsynth/configs/model_configs.py`, used to identify model types and load them. The following fields need to be filled in: + +* `model_hash`: Model file hash value, which can be obtained through the `hash_model_file` function. This hash value is only related to the keys and tensor shapes in the model file's state dict, and is unrelated to other information in the file. +* `model_name`: Model name, used for `Pipeline` to identify the required model. If different structured models play the same role in `Pipeline`, the same `model_name` can be used. When integrating new models, just ensure that `model_name` is different from other existing functional models. The corresponding model is fetched through `model_name` in the `Pipeline`'s `from_pretrained`. +* `model_class`: Model architecture import path, pointing to the model architecture class implemented in Step 1, for example `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`. +* `state_dict_converter`: Optional parameter. If model file format conversion is needed, the import path of the model conversion logic needs to be filled in, for example `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`. +* `extra_kwargs`: Optional parameter. If additional parameters need to be passed when initializing the model, these parameters need to be filled in. 
+For example, the models [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) and [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) both adopt the `QwenImageBlockWiseControlNet` structure in `diffsynth/models/qwen_image_controlnet.py`, but the latter also requires the additional configuration `additional_in_dim=4`, so this configuration information needs to be filled in the `extra_kwargs` field.
+
+The following code illustrates how models are loaded using this configuration information:
+
+```python
+from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization
+from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder
+from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter
+import torch
+
+model_hash = "8004730443f55db63092006dd9f7110e"
+model_name = "qwen_image_text_encoder"
+model_class = QwenImageTextEncoder
+state_dict_converter = QwenImageTextEncoderStateDictConverter
+extra_kwargs = {}
+
+model_path = [
+    "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
+    "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
+    "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
+    "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
+]
+if hash_model_file(model_path) == model_hash:
+    with skip_model_initialization():
+        model = model_class(**extra_kwargs)
+    state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda")
+    state_dict = state_dict_converter(state_dict)
+    model.load_state_dict(state_dict, assign=True)
+    print("Done!")
+```
+
+> Q: The logic of the above code looks very simple, so why is the corresponding code in `DiffSynth-Studio` so complex?
+>
+> A: Because we provide aggressive VRAM management features that are coupled with the model loading logic, which makes the framework structure complex. We have tried our best to simplify the interface exposed to developers.
+
+The `model_hash` values in `diffsynth/configs/model_configs.py` are not necessarily unique, because a single model file may contain multiple models. In this situation, please use multiple model Configs to load each model separately, and write a corresponding `state_dict_converter` that separates out the parameters required by each model.
+
+## Step 4: Verifying Whether the Model Can Be Recognized and Loaded
+
+After model integration, verify that the model can be correctly recognized and loaded.
The following code will attempt to load the model into memory: + +```python +from diffsynth.models.model_loader import ModelPool + +model_pool = ModelPool() +model_pool.auto_load_model( + [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", + ], +) +``` + +If the model can be recognized and loaded, you will see the following output: + +``` +Loading models from: [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +] +Loaded model: { + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "extra_kwargs": null +} +``` + +## Step 5: Writing Model VRAM Management Scheme + +`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details. \ No newline at end of file diff --git a/docs/en/Developer_Guide/Training_Diffusion_Models.md b/docs/en/Developer_Guide/Training_Diffusion_Models.md new file mode 100644 index 0000000000000000000000000000000000000000..3fc92fc90f77c12cabf9a516e7e938a9a77b9398 --- /dev/null +++ b/docs/en/Developer_Guide/Training_Diffusion_Models.md @@ -0,0 +1,66 @@ +# Integrating Model Training + +After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality. + +## Training-Inference Consistent Pipeline Modification + +To ensure strict consistency between training and inference processes, we will use most of the inference code during training, but still need to make minor modifications. + +First, add extra logic during inference to switch the image-to-image/video-to-video logic based on the `scheduler` state. Taking Qwen-Image as an example: + +```python +class QwenImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} +``` + +Then, enable Gradient Checkpointing in `model_fn`, which will significantly reduce the VRAM required for training at the cost of computational speed. This is not mandatory, but we strongly recommend doing so. 
+ +Taking Qwen-Image as an example, before modification: + +```python +text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, +) +``` + +After modification: + +```python +from ..core import gradient_checkpoint_forward + +text, image = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, +) +``` + +## Writing Training Scripts + +`DiffSynth-Studio` does not strictly encapsulate the training framework, but exposes the script content to developers. This approach makes it more convenient to modify training scripts to implement additional functions. Developers can refer to existing training scripts, such as `examples/qwen_image/model_training/train.py`, for modification to adapt to new model training. \ No newline at end of file diff --git a/docs/en/Model_Details/FLUX.md b/docs/en/Model_Details/FLUX.md new file mode 100644 index 0000000000000000000000000000000000000000..1120a34c78bb9a40fd0efcb45fb79b31267d0298 --- /dev/null +++ b/docs/en/Model_Details/FLUX.md @@ -0,0 +1,201 @@ +# FLUX + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +FLUX is an image generation model series developed and open-sourced by Black Forest Labs. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run. + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1, +) +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." 
+image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") +``` + +## Model Overview + +
+ +Model Lineage + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
+ +| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | - | +| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | +| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | +| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | +| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | +| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | +| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | 
[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | +| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | +| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | +| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | +| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | +| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | +| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) | +| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) 
| [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | +| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) | + +Special Training Scripts: + +* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/) +* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) + +## Model Inference + +Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). + +Input parameters for `FluxImagePipeline` inference include: + +* `prompt`: Prompt describing the content appearing in the image. +* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled. +* `height`: Image height, must be a multiple of 16. +* `width`: Image width, must be a multiple of 16. +* `seed`: Random seed. Default is `None`, meaning completely random. +* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results. +* `num_inference_steps`: Number of inference steps, default value is 30. +* `embedded_guidance`: Embedded guidance parameter, default value is 3.5. +* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512. +* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time. +* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`. +* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`. +* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. +* `controlnet_inputs`: ControlNet model inputs, type is `ControlNetInput` list. +* `ipadapter_images`: IP-Adapter model input image list. +* `ipadapter_scale`: Guidance strength of the IP-Adapter model. +* `infinityou_id_image`: InfiniteYou model input image. +* `infinityou_guidance`: Guidance strength of the InfiniteYou model. +* `kontext_images`: Kontext model input images. +* `eligen_entity_prompts`: EliGen partition control prompt list. +* `eligen_entity_masks`: EliGen partition control region mask image list. 
+* `eligen_enable_on_negative`: Whether to enable EliGen partition control on the negative side of CFG. +* `eligen_enable_inpaint`: Whether to enable EliGen partition control inpainting function. +* `lora_encoder_inputs`: LoRA encoder input image list. +* `lora_encoder_scale`: Guidance strength of the LoRA encoder. +* `step1x_reference_image`: Step1X model reference image. +* `flex_inpaint_image`: Flex model image to be inpainted. +* `flex_inpaint_mask`: Flex model inpainting mask. +* `flex_control_image`: Flex model control image. +* `flex_control_strength`: Flex model control strength. +* `flex_control_stop`: Flex model control stop timestep. +* `nexus_gen_reference_image`: Nexus-Gen model reference image. + +If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. + +## Model Training + +FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. + * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. 
If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models) + * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged. +* FLUX Specific Parameters + * `--tokenizer_1_path`: Path of the CLIP tokenizer, leave blank to automatically download from remote. + * `--tokenizer_2_path`: Path of the T5 tokenizer, leave blank to automatically download from remote. + * `--align_to_opensource_format`: Whether to align LoRA format to open-source format, only applicable to DiT's LoRA. + +We have built a sample image dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file diff --git a/docs/en/Model_Details/FLUX2.md b/docs/en/Model_Details/FLUX2.md new file mode 100644 index 0000000000000000000000000000000000000000..fd5e56d3e15a68325e5879196e3463823673af53 --- /dev/null +++ b/docs/en/Model_Details/FLUX2.md @@ -0,0 +1,138 @@ +# FLUX.2 + +FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 10GB VRAM is required to run. 
+ +```python +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image.jpg") +``` + +## Model Overview + +| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | +| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) | + +Special Training Scripts: + +* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/) +* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) + +## Model Inference + +Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). + +Input parameters for `Flux2ImagePipeline` inference include: + +* `prompt`: Prompt describing the content appearing in the image. +* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled. +* `height`: Image height, must be a multiple of 16. +* `width`: Image width, must be a multiple of 16. +* `seed`: Random seed. Default is `None`, meaning completely random. +* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results. 
+* `num_inference_steps`: Number of inference steps, default value is 30. +* `embedded_guidance`: Embedded guidance parameter, default value is 3.5. +* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512. +* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time. +* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`. +* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`. +* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. + +If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. + +## Model Training + +FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. + * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. 
+ * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models) + * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged. +* FLUX.2 Specific Parameters + * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote. + +We have built a sample image dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file diff --git a/docs/en/Model_Details/Overview.md b/docs/en/Model_Details/Overview.md new file mode 100644 index 0000000000000000000000000000000000000000..5df859302de70667e39439cbed2d2d7e3758d1fd --- /dev/null +++ b/docs/en/Model_Details/Overview.md @@ -0,0 +1,291 @@ +# Model Directory + +## Qwen-Image + +Documentation: [./Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md) + +
+ +Effect Preview + +![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924) + +
+
+
+ +Quick Start + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from PIL import Image +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, num_inference_steps=40, + # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit +) +image.save("image.jpg") +``` + +
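+
+The snippet above is plain text-to-image generation. As a minimal follow-up sketch (not one of the linked example scripts), the same `pipe` object can also run in image-to-image mode through the `input_image` and `denoising_strength` arguments documented on the Qwen-Image page; the file name `input.jpg` and the strength value are placeholders.
+
+```python
+# Image-to-image sketch reusing the `pipe` object built in the Quick Start above.
+# `input_image` and `denoising_strength` are documented QwenImagePipeline arguments;
+# "input.jpg" and the 0.7 strength are illustrative placeholders.
+from PIL import Image
+
+init_image = Image.open("input.jpg").resize((1328, 1328))
+image = pipe(
+    prompt="精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",
+    input_image=init_image,
+    denoising_strength=0.7,  # near 0 stays close to the input image, near 1 diverges from it
+    seed=0, num_inference_steps=40,
+)
+image.save("image_img2img.jpg")
+```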
+
+
+ +Model Lineage + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | +| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | +| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | 
[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | +| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | +| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | 
[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - | + +## FLUX Series + +Documentation: [./FLUX.md](/docs/en/Model_Details/FLUX.md) + +
+ +Effect Preview + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +
+
+
+ +Quick Start + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image = pipe(prompt="a cat", seed=0) +image.save("image.jpg") +``` + +
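+
+As a minimal sketch of the optional arguments documented on the FLUX page (`negative_prompt`, `cfg_scale`, `embedded_guidance`, `tiled`, `tile_size`, `tile_stride`, and so on), the call below reuses the `pipe` object from the Quick Start; the prompt and the specific values are illustrative, not recommended settings.
+
+```python
+# Reuses the `pipe` object built in the Quick Start above and passes a few of the
+# optional FluxImagePipeline arguments documented on the FLUX page.
+image = pipe(
+    prompt="a cat sitting on a windowsill, golden hour lighting, highly detailed",
+    negative_prompt="blurry, low quality",
+    cfg_scale=2.0,                  # values greater than 1 enable classifier-free guidance
+    embedded_guidance=3.5,          # embedded guidance strength (default 3.5)
+    num_inference_steps=30,
+    height=1024, width=1024,        # must be multiples of 16
+    tiled=True, tile_size=128, tile_stride=64,  # tiled VAE to reduce VRAM usage
+    seed=0,
+)
+image.save("image_with_options.jpg")
+```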
+
+
+ +Model Lineage + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
+ +| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | - | +| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) | +| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) | +| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) | +| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | +| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) | +| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | 
[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) | +| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) | +| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) | +| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) | +| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - | +| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - | +| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) | +| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) 
| [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) | +| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) | + +## Wan Series + +Documentation: [./Wan.md](/docs/en/Model_Details/Wan.md) + +
+ +Effect Preview + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +
+
+
+ +Quick Start + +```python +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` + +
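+
+The Quick Start above is text-to-video. Image-to-video models in the Wan family additionally take an `input_image` argument (see the "Extra Parameters" column in the table later in this section). The sketch below shows the general call pattern under stated assumptions: the `origin_file_pattern` values for Wan-AI/Wan2.1-I2V-14B-480P are assumed and should be checked against the model repository or the linked inference example, and the image path is a placeholder.
+
+```python
+# Hedged image-to-video sketch. The file patterns below are assumptions; refer to the
+# Wan2.1-I2V-14B-480P inference example linked in the table for the authoritative config.
+import torch
+from PIL import Image
+from diffsynth.utils.data import save_video
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),  # assumed pattern
+        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),  # assumed pattern
+        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),  # assumed image encoder
+        ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth"),  # assumed pattern
+    ],
+)
+video = pipe(
+    prompt="a lively puppy running across a sunlit green meadow",  # illustrative prompt
+    input_image=Image.open("first_frame.jpg"),  # placeholder first-frame image
+    seed=0, tiled=True,
+)
+save_video(video, "video_i2v.mp4", fps=15, quality=5)
+```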
+
+
+ +Model Lineage + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
+ +| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | +| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | +| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | +| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | +| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | +| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | +| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | +| 
[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | +| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | +| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | 
[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | +| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | +| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | +| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | 
[code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | +| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | +| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | +| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | +| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | +| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | +| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | +| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | 
[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | +| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | +| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | +| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | + +* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) \ No newline at end of file diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md new file mode 100644 index 0000000000000000000000000000000000000000..2f6a0dcf529e91044e3fac19a1bec2a52d1235d6 --- /dev/null +++ b/docs/en/Model_Details/Qwen-Image.md @@ -0,0 +1,192 @@ +# Qwen-Image + +![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924) + +Qwen-Image is an image generation model trained and open-sourced by the Tongyi Lab Qwen Team of Alibaba. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run. 
+ +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## Model Overview + +
+ +Model Lineage + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) | +| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) | +| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | +| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | 
[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) | +| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) | +| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | +| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) | +| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | 
[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
+|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
+
+Special Training Scripts:
+
+* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/qwen_image/model_training/special/differential_training/)
+* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/qwen_image/model_training/special/fp8_training/)
+* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
+* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
+
+## Model Inference
+
+Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
+
+Input parameters for `QwenImagePipeline` inference include:
+
+* `prompt`: Prompt describing the content appearing in the image.
+* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`.
+* `cfg_scale`: Classifier-free guidance parameter, default value is 4. When set to 1, it no longer takes effect.
+* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`.
+* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image is similar to the input image; when the value approaches 1, the generated image differs more from the input image. When the `input_image` parameter is not provided, do not set this to a value other than 1.
+* `inpaint_mask`: Mask image for image inpainting.
+* `inpaint_blur_size`: Edge blur width for image inpainting.
+* `inpaint_blur_sigma`: Edge blur strength for image inpainting.
+* `height`: Image height, must be a multiple of 16.
+* `width`: Image width, must be a multiple of 16.
+* `seed`: Random seed. Default is `None`, meaning completely random.
+* `rand_device`: Computing device for generating the random Gaussian noise matrix, default is `"cpu"`. When set to `"cuda"`, different GPUs will produce different generation results.
+* `num_inference_steps`: Number of inference steps, default value is 30.
+* `exponential_shift_mu`: Fixed `mu` value used when sampling timesteps. Leave blank to sample based on the image width and height.
+* `blockwise_controlnet_inputs`: Blockwise ControlNet model inputs.
+* `eligen_entity_prompts`: EliGen partition control prompts.
+* `eligen_entity_masks`: EliGen partition control region mask images.
+* `eligen_enable_on_negative`: Whether to enable EliGen partition control on the negative side of CFG.
+* `edit_image`: Image(s) to be edited (for the Edit models); multiple images are supported.
+* `edit_image_auto_resize`: Whether to automatically resize edit images.
+* `edit_rope_interpolation`: Whether to enable RoPE interpolation for low-resolution edit images.
+* `context_image`: In-Context Control input image.
+* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting it to `True` can significantly reduce VRAM usage during the VAE encoding/decoding stages, at the cost of slight numerical error and slightly longer inference time.
+* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`. +* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`. +* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. + +If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. + +## Model Training + +Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. + * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters `edit_image` when training image editing model Qwen-Image-Edit, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. 
+ * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models) + * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged. +* Qwen-Image Specific Parameters + * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote. + * `--processor_path`: Path of the processor, applicable to image editing models, leave blank to automatically download from remote. + +We have built a sample image dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md new file mode 100644 index 0000000000000000000000000000000000000000..83141bffef134edae3e3a35183e3d96412b5a65e --- /dev/null +++ b/docs/en/Model_Details/Wan.md @@ -0,0 +1,252 @@ +# Wan + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +Wan is a video generation model series developed by the Tongyi Wanxiang Team of Alibaba Tongyi Lab. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run. 
+ +```python +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` + +## Model Overview + +
+ +Model Lineage + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
+ +| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) | +| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) | +| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) | +| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) | +| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) | +| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) | +| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) | +| 
[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) | +| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) | +| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | 
[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) | +| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | +| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | +| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) | +| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) | +| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | 
[code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) | +| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) | +| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) | +| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) | +| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) | +| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) | +| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) | +| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | 
[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) | +| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | +| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | +| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | + +* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/) +* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) +* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) + +## Model Inference + +Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). + +Input parameters for `WanVideoPipeline` inference include: + +* `prompt`: Prompt describing the content appearing in the video. +* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 5. When set to 1, it no longer takes effect. +* `input_image`: Input image for image-to-video generation, used in conjunction with `denoising_strength`. +* `end_image`: End image for first-and-last frame video generation. +* `input_video`: Input video for video-to-video generation, used in conjunction with `denoising_strength`. +* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated video is similar to the input video; when the value approaches 1, the generated video differs more from the input video. +* `control_video`: Control video for controlling the video generation process. +* `reference_image`: Reference image for maintaining consistency of certain features in the generated video. 
+* `camera_control_direction`: Camera control direction, optional values are `"Left"`, `"Right"`, `"Up"`, `"Down"`, `"LeftUp"`, `"LeftDown"`, `"RightUp"`, `"RightDown"`. +* `camera_control_speed`: Camera control speed, default value is 1/54. +* `vace_video`: VACE control video. +* `vace_video_mask`: VACE control video mask. +* `vace_reference_image`: VACE reference image. +* `vace_scale`: VACE control strength, default value is 1.0. +* `animate_pose_video`: `animate` model pose video. +* `animate_face_video`: `animate` model face video. +* `animate_inpaint_video`: `animate` model local editing video. +* `animate_mask_video`: `animate` model mask video. +* `vap_video`: `video-as-prompt` input video. +* `vap_prompt`: `video-as-prompt` text description. +* `negative_vap_prompt`: `video-as-prompt` negative text description. +* `input_audio`: Input audio for speech-to-video generation. +* `audio_embeds`: Audio embedding vectors. +* `audio_sample_rate`: Audio sampling rate, default value is 16000. +* `s2v_pose_video`: S2V model pose video. +* `motion_video`: S2V model motion video. +* `height`: Video height, must be a multiple of 16. +* `width`: Video width, must be a multiple of 16. +* `num_frames`: Number of video frames, default value is 81, must be a multiple of 4 + 1. +* `seed`: Random seed. Default is `None`, meaning completely random. +* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results. +* `num_inference_steps`: Number of inference steps, default value is 50. +* `motion_bucket_id`: Motion control parameter, the larger the value, the greater the motion amplitude. +* `longcat_video`: LongCat input video. +* `tiled`: Whether to enable VAE tiling inference, default is `True`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time. +* `tile_size`: Tile size during VAE encoding/decoding stages, default is `(30, 52)`, only effective when `tiled=True`. +* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is `(15, 26)`, only effective when `tiled=True`, must be less than or equal to `tile_size`. +* `switch_DiT_boundary`: Time boundary for switching DiT models, default value is 0.875. +* `sigma_shift`: Timestep offset parameter, default value is 5.0. +* `sliding_window_size`: Sliding window size. +* `sliding_window_stride`: Sliding window stride. +* `tea_cache_l1_thresh`: L1 threshold for TeaCache. +* `tea_cache_model_id`: Model ID used by TeaCache. +* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`. + +If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. + +## Model Training + +Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. 
+ * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Video Width/Height Configuration + * `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged. + * `--num_frames`: Number of frames in the video. +* Wan Series Specific Parameters + * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote. + * `--audio_processor_path`: Path of the audio processor, applicable to speech-to-video models, leave blank to automatically download from remote. 
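+
+As a rough illustration of how these flags combine, a LoRA fine-tuning command for Wan2.1-T2V-1.3B might look like the sketch below. This is only a hedged sketch, not the recommended script (those are linked in the "Model Overview" table above): the `accelerate launch` launcher, the `metadata.csv` file name, the resolution, the LoRA target modules, the `pipe.dit.` checkpoint prefix, and the other hyperparameter values are illustrative assumptions, and the command presumes the example video dataset described below has already been downloaded to `./data/example_video_dataset`.
+
+```shell
+# Illustrative LoRA training run; the launcher, metadata file name, and
+# hyperparameter values are assumptions, not the official recommended script.
+accelerate launch examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata.csv \
+  --height 480 --width 832 --num_frames 81 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --lora_base_model dit \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path ./models/train/Wan2.1-T2V-1.3B_lora
+```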
+ +We have built a sample video dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). \ No newline at end of file diff --git a/docs/en/Model_Details/Z-Image.md b/docs/en/Model_Details/Z-Image.md new file mode 100644 index 0000000000000000000000000000000000000000..7490b2f12c904ff07a7d47f73026f0401ff8143b --- /dev/null +++ b/docs/en/Model_Details/Z-Image.md @@ -0,0 +1,141 @@ +# Z-Image + +Z-Image is an image generation model trained and open-sourced by the Multimodal Interaction Team of Alibaba Tongyi Lab. + +## Installation + +Before using this project for model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md). + +## Quick Start + +Run the following code to quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model and perform inference. FP8 precision quantization causes noticeable image quality degradation, so it is not recommended to enable any quantization on the Z-Image Turbo model. Only CPU Offload is recommended, minimum 8GB VRAM is required to run. + +```python +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." 
+image = pipe(prompt=prompt, seed=42, rand_device="cuda") +image.save("image.jpg") +``` + +## Model Overview + +| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | +| - | - | - | - | - | - | - | +| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) | + +Special Training Scripts: + +* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/z_image/model_training/special/differential_training/) +* Trajectory Imitation Distillation Training (Experimental Feature): [code](/examples/z_image/model_training/special/trajectory_imitation/) + +## Model Inference + +Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). + +Input parameters for `ZImagePipeline` inference include: + +* `prompt`: Prompt describing the content appearing in the image. +* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`. +* `cfg_scale`: Classifier-free guidance parameter, default value is 1. +* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`. +* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image is similar to the input image; when the value approaches 1, the generated image differs more from the input image. When `input_image` parameter is not provided, do not set this to a non-1 value. +* `height`: Image height, must be a multiple of 16. +* `width`: Image width, must be a multiple of 16. +* `seed`: Random seed. Default is `None`, meaning completely random. +* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results. +* `num_inference_steps`: Number of inference steps, default value is 8. + +If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. + +## Model Training + +Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py), and the script parameters include: + +* General Training Parameters + * Dataset Basic Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each DataLoader. + * `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths of models to be loaded. JSON format. 
+ * `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). + * Training Basic Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). + * `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model. + * Output Configuration + * `--output_path`: Model saving path. + * `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file. + * `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models) + * `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged. +* Z-Image Specific Parameters + * `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote. + +We have built a sample image dataset for your testing. You can download this dataset with the following command: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. 
For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/). + +Training Tips: + +* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with "acceleration configuration" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with "no acceleration configuration" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. The following training and inference schemes can be adopted: + * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference + * Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference + * An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter) + * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference + * Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + Acceleration Configuration Inference \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Environment_Variables.md b/docs/en/Pipeline_Usage/Environment_Variables.md new file mode 100644 index 0000000000000000000000000000000000000000..91d016fffa33b9ba7e3939567ea802321f416ccc --- /dev/null +++ b/docs/en/Pipeline_Usage/Environment_Variables.md @@ -0,0 +1,39 @@ +# Environment Variables + +`DiffSynth-Studio` can control some settings through environment variables. + +In `Python` code, you can set environment variables using `os.environ`. Please note that environment variables must be set before `import diffsynth`. + +```python +import os +os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "./path_to_my_models" +import diffsynth +``` + +On Linux operating systems, you can also temporarily set environment variables from the command line: + +```shell +DIFFSYNTH_MODEL_BASE_PATH="./path_to_my_models" python xxx.py +``` + +Below are the environment variables supported by `DiffSynth-Studio`. + +## `DIFFSYNTH_SKIP_DOWNLOAD` + +Whether to skip model downloads. Can be set to `True`, `true`, `False`, `false`. If `skip_download` is not set in `ModelConfig`, this environment variable will determine whether to skip model downloads. + +## `DIFFSYNTH_MODEL_BASE_PATH` + +Model download root directory. Can be set to any local path. If `local_model_path` is not set in `ModelConfig`, model files will be downloaded to the path pointed to by this environment variable. If neither is set, model files will be downloaded to `./models`. + +## `DIFFSYNTH_ATTENTION_IMPLEMENTATION` + +Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](/docs/en/API_Reference/core/attention.md) for details. 
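+
+For example, to try a specific backend for a single run without modifying any code, the variable can be set inline on the command line, following the same pattern shown above (here `xxx.py` stands in for your own inference script, and the chosen backend's library is assumed to be installed):
+
+```shell
+# Force the SageAttention backend for this run only (requires the sageattention package).
+DIFFSYNTH_ATTENTION_IMPLEMENTATION="sage_attention" python xxx.py
+```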
+ +## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE` + +Buffer size in disk mapping. Default is 1B (1000000000). Larger values occupy more memory but result in faster speeds. + +## `DIFFSYNTH_DOWNLOAD_SOURCE` + +Remote model download source. Can be set to `modelscope` or `huggingface` to control the source of model downloads. Default value is `modelscope`. \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Model_Inference.md b/docs/en/Pipeline_Usage/Model_Inference.md new file mode 100644 index 0000000000000000000000000000000000000000..e5a85a071fba5d36b496743aa24bfe7dae048d9f --- /dev/null +++ b/docs/en/Pipeline_Usage/Model_Inference.md @@ -0,0 +1,105 @@ +# Model Inference + +This document uses the Qwen-Image model as an example to introduce how to use `DiffSynth-Studio` for model inference. + +## Loading Models + +Models are loaded through `from_pretrained`: + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +``` + +Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md). + +
+ +Download and load models from remote sources + +> `DiffSynth-Studio` downloads and loads models from [ModelScope](https://www.modelscope.cn/) by default. You need to fill in `model_id` and `origin_file_pattern`, for example: +> +> ```python +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), +> ``` +> +> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). + +
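> The download directory can also be overridden for a single model through the `local_model_path` field of `ModelConfig` (see the environment variables documentation). A minimal sketch, assuming the keyword form shown here:
>
> ```python
> from diffsynth.pipelines.qwen_image import ModelConfig
>
> # Hypothetical example: place this model under a custom directory instead of ./models.
> ModelConfig(
>     model_id="Qwen/Qwen-Image",
>     origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
>     local_model_path="./path_to_my_models",
> )
> ```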
+ +
+ +Load models from local file paths + +> Fill in `path`, for example: +> +> ```python +> ModelConfig(path="models/xxx.safetensors") +> ``` +> +> For models loaded from multiple files, use a list, for example: +> +> ```python +> ModelConfig(path=[ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", +> ]) +> ``` + +
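> Instead of typing every shard path by hand, the list can also be assembled with the standard library (a minimal sketch, assuming the shards already exist locally and follow the usual `*-of-*.safetensors` naming):
>
> ```python
> from glob import glob
> from diffsynth.pipelines.qwen_image import ModelConfig
>
> # sorted() keeps the shard order deterministic.
> ModelConfig(path=sorted(glob("models/Qwen/Qwen-Image/text_encoder/model-*-of-*.safetensors")))
> ```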
+ +By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. + +```shell +import os +os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True" +import diffsynth +``` + +To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`. + +```shell +import os +os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface" +import diffsynth +``` + +## Starting Inference + +Input a prompt to start the inference process and generate an image. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +Each model `Pipeline` has different input parameters. Please refer to the documentation for each model. + +If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md). \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Model_Training.md b/docs/en/Pipeline_Usage/Model_Training.md new file mode 100644 index 0000000000000000000000000000000000000000..3c5bffd40212b5d43e229c10501db3b965ee7751 --- /dev/null +++ b/docs/en/Pipeline_Usage/Model_Training.md @@ -0,0 +1,247 @@ +# Model Training + +This document introduces how to use `DiffSynth-Studio` for model training. + +## Script Parameters + +Training scripts typically include the following parameters: + +* Dataset base configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Metadata file path of the dataset. + * `--dataset_repeat`: Number of times the dataset is repeated in each epoch. + * `--dataset_num_workers`: Number of processes for each Dataloader. + * `--data_file_keys`: Field names that need to be loaded from metadata, usually image or video file paths, separated by `,`. +* Model loading configuration + * `--model_paths`: Paths of models to be loaded. JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, for example `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`. Separated by commas. + * `--extra_inputs`: Extra input parameters required by the model Pipeline, for example, training image editing model Qwen-Image-Edit requires extra parameter `edit_image`, separated by `,`. + * `--fp8_models`: Models loaded in FP8 format, consistent with the format of `--model_paths` or `--model_id_with_origin_paths`. 
Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA). +* Training base configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, for example `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training. + * `--weight_decay`: Weight decay size. See [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html) for details. + * `--task`: Training task, default is `sft`. Some models support more training modes. Please refer to the documentation for each specific model. +* Output configuration + * `--output_path`: Model save path. + * `--remove_prefix_in_ckpt`: Remove prefixes in the state dict of model files. + * `--save_steps`: Interval of training steps for saving models. If this parameter is left blank, the model will be saved once per epoch. +* LoRA configuration + * `--lora_base_model`: Which model LoRA is added to. + * `--lora_target_modules`: Which layers LoRA is added to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path of LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint. + * `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training. + * `--preset_lora_model`: Model that preset LoRA is merged into, for example `dit`. +* Gradient configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. +* Image dimension configuration (applicable to image generation models and video generation models) + * `--height`: Height of images or videos. Leave `height` and `width` blank to enable dynamic resolution. + * `--width`: Width of images or videos. Leave `height` and `width` blank to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area of images or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be scaled down, and images with resolution smaller than this value will remain unchanged. + +Some models' training scripts also contain additional parameters. See the documentation for each model for details. + +## Preparing Datasets + +`DiffSynth-Studio` adopts a universal dataset format. The dataset contains a series of data files (images, videos, etc.) and annotated metadata files. We recommend organizing dataset files as follows: + +``` +data/example_image_dataset/ +├── metadata.csv +├── image_1.jpg +└── image_2.jpg +``` + +Where `image_1.jpg`, `image_2.jpg` are training image data, and `metadata.csv` is the metadata list, for example: + +``` +image,prompt +image_1.jpg,"a dog" +image_2.jpg,"a cat" +``` + +We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md). + +
+ +Sample Image Dataset + +> ```shell +> modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +> ``` +> +> Applicable to training of image generation models such as Qwen-Image and FLUX. + +
+ +
+ +Sample Video Dataset + +> ```shell +> modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +> ``` +> +> Applicable to training of video generation models such as Wan. + +
+ +## Loading Models + +Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed. + +
+ +Download and load models from remote sources + +> If we load models during inference through the following settings: +> +> ```python +> model_configs=[ +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), +> ] +> ``` +> +> Then during training, fill in the following parameters to load the corresponding models: +> +> ```shell +> --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" +> ``` +> +> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path). +> +> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`. + +
+ +
+ +
+ +Load models from local file paths + +> If loading models from local files during inference, for example: +> +> ```python +> model_configs=[ +> ModelConfig([ +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors" +> ]), +> ModelConfig([ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +> ]), +> ModelConfig("models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors") +> ] +> ``` +> +> Then during training, set to: +> +> ```shell +> --model_paths '[ +> [ +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors" +> ], +> [ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +> ], +> "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors" +> ]' \ +> ``` +> +> Note that `--model_paths` is in JSON format, and extra `,` cannot appear in it, otherwise it cannot be parsed normally. + +
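> Since malformed JSON is a common source of launch failures here, a quick sanity check with the standard library can help (a minimal sketch; paste the exact string you pass to `--model_paths`):
>
> ```python
> import json
>
> # Hypothetical value for illustration; substitute your real --model_paths string.
> model_paths = '["models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors"]'
> json.loads(model_paths)  # raises json.JSONDecodeError if the string is not valid JSON
> ```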
+ +## Setting Trainable Modules + +The training framework supports training of any model. Taking Qwen-Image as an example, to fully train the DiT model, set to: + +```shell +--trainable_models "dit" +``` + +To train LoRA of the DiT model, set to: + +```shell +--lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32 +``` + +We hope to leave enough room for technical exploration, so the framework supports training any number of modules simultaneously. For example, to train the text encoder, controlnet, and LoRA of the DiT simultaneously: + +```shell +--trainable_models "text_encoder,controlnet" --lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32 +``` + +Additionally, since the training script loads multiple modules (text encoder, dit, vae, etc.), prefixes need to be removed when saving model files. For example, when fully training the DiT part or training the LoRA model of the DiT part, please set `--remove_prefix_in_ckpt pipe.dit.`. If multiple modules are trained simultaneously, developers need to write code to split the state dict in the model file after training is completed. + +## Starting the Training Program + +The training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Training commands are written in the following format: + +```shell +accelerate launch xxx/train.py \ + --xxx yyy \ + --xxxx yyyy +``` + +We have written preset training scripts for each model. See the documentation for each model for details. + +By default, `accelerate` will train according to the configuration in `~/.cache/huggingface/accelerate/default_config.yaml`. Use `accelerate config` to configure interactively in the terminal, including multi-GPU training, [`DeepSpeed`](https://www.deepspeed.ai/), etc. + +We provide recommended `accelerate` configuration files for some models, which can be set through `--config_file`. For example, full training of the Qwen-Image model: + +```shell +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters +``` + +## Training Considerations + +* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [/docs/en/API_Reference/core/data.md#metadata](/docs/en/API_Reference/core/data.md#metadata) +* Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the `--save_steps` parameter to save model files at training step intervals. +* When data volume * `dataset_repeat` exceeds $10^9$, we observed that the dataset speed becomes significantly slower, which seems to be a `PyTorch` bug. We are not sure if newer versions of `PyTorch` have fixed this issue. 
+* For learning rate `--learning_rate`, it is recommended to set to `1e-4` in LoRA training and `1e-5` in full training. +* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](/docs/en/QA.md#why-doesnt-the-training-framework-support-batch-size--1) +* Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters. +* The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges. +* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md) for details. \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/Setup.md b/docs/en/Pipeline_Usage/Setup.md new file mode 100644 index 0000000000000000000000000000000000000000..c9fba68c3316efe335424f9f82d17329d302e54f --- /dev/null +++ b/docs/en/Pipeline_Usage/Setup.md @@ -0,0 +1,21 @@ +# Installing Dependencies + +Install from source (recommended): + +``` +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +Install from PyPI (there may be delays in version updates; for latest features, install from source): + +``` +pip install diffsynth +``` + +If you encounter issues during installation, they may be caused by upstream dependency packages. Please refer to the documentation for these packages: + +* [torch](https://pytorch.org/get-started/locally/) +* [sentencepiece](https://github.com/google/sentencepiece) +* [cmake](https://cmake.org) \ No newline at end of file diff --git a/docs/en/Pipeline_Usage/VRAM_management.md b/docs/en/Pipeline_Usage/VRAM_management.md new file mode 100644 index 0000000000000000000000000000000000000000..ecf5379e4aa06948613426b59b52214dc7b3e40b --- /dev/null +++ b/docs/en/Pipeline_Usage/VRAM_management.md @@ -0,0 +1,206 @@ +# VRAM Management + +VRAM management is a distinctive feature of `DiffSynth-Studio` that enables GPUs with low VRAM to run inference with large parameter models. This document uses Qwen-Image as an example to introduce how to use the VRAM management solution. + +## Basic Inference + +The following code does not enable any VRAM management, occupying 56G VRAM as a reference. 
+ +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## CPU Offload + +Since the model `Pipeline` consists of multiple components that are not called simultaneously, we can move some components to memory when they are not needed for computation, reducing VRAM usage. The following code implements this logic, occupying 40G VRAM. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## FP8 Quantization + +Building upon CPU Offload, we further enable FP8 quantization to reduce VRAM requirements. The following code allows model parameters to be stored in VRAM with FP8 precision and temporarily converted to BF16 precision for computation during inference, occupying 21G VRAM. However, this quantization scheme has minor image quality degradation issues. 
+ +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cuda", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +> Q: Why temporarily convert to BF16 precision during inference instead of computing with FP8 precision? +> +> A: Native FP8 computation is only supported on Hopper architecture GPUs (such as H20) and has significant computational errors. We currently do not enable FP8 precision computation. The current FP8 quantization only reduces VRAM usage but does not improve computation speed. + +## Dynamic VRAM Management + +In CPU Offload, we control model components. In fact, we support Layer-level Offload, splitting a model into multiple Layers, keeping some resident in VRAM and storing others in memory for on-demand transfer to VRAM for computation. This feature requires model developers to provide detailed VRAM management solutions for each model. Related configurations are in `diffsynth/configs/vram_management_module_maps.py`. + +By adding the `vram_limit` parameter to the `Pipeline`, the framework can automatically sense the remaining VRAM of the device and decide how to split the model between VRAM and memory. The smaller the `vram_limit`, the less VRAM occupied, but slower the speed. +* When `vram_limit=None`, the default state, the framework assumes unlimited VRAM and dynamic VRAM management is disabled +* When `vram_limit=10`, the framework will limit the model after VRAM usage exceeds 10G, moving the excess parts to memory storage +* When `vram_limit=0`, the framework will do its best to reduce VRAM usage, storing all model parameters in memory and transferring them to VRAM for computation only when necessary + +When VRAM is insufficient to run model inference, the framework will attempt to exceed the `vram_limit` restriction to keep the model inference running. Therefore, the VRAM management framework cannot always guarantee that VRAM usage will be less than `vram_limit`. We recommend setting it to slightly less than the actual available VRAM. For example, when GPU VRAM is 16G, set it to `vram_limit=15.5`. In `PyTorch`, you can use `torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)` to get the GPU's VRAM. 
+ +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## Disk Offload + +In more extreme cases, when memory is also insufficient to store the entire model, the Disk Offload feature allows lazy loading of model parameters, meaning each Layer of the model only reads the corresponding parameters from disk when the forward function is called. When enabling this feature, we recommend using high-speed SSD drives. + +Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape. + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=10, +) +prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal." 
+image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## More Usage Methods + +Information in `vram_config` can be filled in manually, for example, Disk Offload without FP8 quantization: + +```python +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +``` + +Specifically, the VRAM management module divides model Layers into the following four states: + +* Offload: This model will not be called in the short term. This state is controlled by switching `Pipeline` +* Onload: This model will be called at any time soon. This state is controlled by switching `Pipeline` +* Preparing: Intermediate state between Onload and Computation. A temporary storage state when VRAM allows. This state is controlled by the VRAM management mechanism and enters this state if and only if [vram_limit is set to unlimited] or [vram_limit is set and there is spare VRAM] +* Computation: The model is being computed. This state is controlled by the VRAM management mechanism and is temporarily entered only during `forward` + +If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](/docs/en/Developer_Guide/Enabling_VRAM_management.md). + +## Best Practices + +* Sufficient VRAM -> Use [Basic Inference](#basic-inference) +* Insufficient VRAM + * Sufficient memory -> Use [Dynamic VRAM Management](#dynamic-vram-management) + * Insufficient memory -> Use [Disk Offload](#disk-offload) \ No newline at end of file diff --git a/docs/en/QA.md b/docs/en/QA.md new file mode 100644 index 0000000000000000000000000000000000000000..fe7546025bbfc9d5e8e336ea7dc195b081ab958b --- /dev/null +++ b/docs/en/QA.md @@ -0,0 +1,28 @@ +# Frequently Asked Questions + +## Why doesn't the training framework support batch size > 1? + +* **Larger batch sizes no longer achieve significant acceleration**: Due to acceleration technologies such as flash attention that have fully improved GPU utilization, larger batch sizes will only bring greater VRAM usage without significant acceleration. The experience with small models like Stable Diffusion 1.5 is no longer applicable to the latest large models. +* **Larger batch sizes can be achieved through other solutions**: Multi-GPU training and Gradient Accumulation can both mathematically equivalently achieve larger batch sizes. +* **Larger batch sizes contradict the framework's general design**: We hope to build a general training framework. Many models cannot accommodate larger batch sizes, such as text encodings of different lengths and images of different resolutions, which cannot be merged into larger batches. + +## Why aren't redundant parameters removed from certain models? + +In some models, redundant parameters exist. For example, in Qwen-Image's DiT model, the text portion of the last layer does not participate in any calculations. This is a minor bug left by the model developers. Setting it as trainable directly will also cause errors in multi-GPU training. + +To maintain compatibility with other models in the open-source community, we have decided to retain these parameters. These redundant parameters can avoid errors in multi-GPU training through the `--find_unused_parameters` parameter. + +## Why does FP8 quantization show no acceleration effect? 
+ +Native FP8 computation relies on Hopper architecture GPUs and has significant precision errors. It is currently immature technology, so this project does not support native FP8 computation. + +FP8 computation in VRAM management refers to storing model parameters in memory or VRAM with FP8 precision and temporarily converting them to other precisions when needed for computation. Therefore, it can only reduce VRAM usage without acceleration effects. + +## Why doesn't the training framework support native FP8 precision training? + +Even with suitable hardware conditions, we currently have no plans to support native FP8 precision training. + +* The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present. +* Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8. + +Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community. \ No newline at end of file diff --git a/docs/en/README.md b/docs/en/README.md new file mode 100644 index 0000000000000000000000000000000000000000..17c8e8892b403cff938e6d360be00278ca9af6f8 --- /dev/null +++ b/docs/en/README.md @@ -0,0 +1,88 @@ +# DiffSynth-Studio Documentation + +Welcome to the magical world of Diffusion models! `DiffSynth-Studio` is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We aim to build a universal Diffusion model framework that fosters technological innovation through framework construction, aggregates the power of the open-source community, and explores the boundaries of generative model technology! + +
+ +Documentation Reading Guide + +```mermaid +graph LR; + I_want_to_use_models_for_inference_and_training-->sec1[Section 1: Getting Started]; + I_want_to_use_models_for_inference_and_training-->sec2[Section 2: Model Details]; + I_want_to_use_models_for_inference_and_training-->sec3[Section 3: Training Framework]; + I_want_to_develop_based_on_this_framework-->sec3[Section 3: Training Framework]; + I_want_to_develop_based_on_this_framework-->sec4[Section 4: Model Integration]; + I_want_to_develop_based_on_this_framework-->sec5[Section 5: API Reference]; + I_want_to_explore_new_technologies_based_on_this_project-->sec4[Section 4: Model Integration]; + I_want_to_explore_new_technologies_based_on_this_project-->sec5[Section 5: API Reference]; + I_want_to_explore_new_technologies_based_on_this_project-->sec6[Section 6: Academic Guide]; + I_encountered_a_problem-->sec7[Section 7: Frequently Asked Questions]; +``` + +
+ +## Section 1: Getting Started + +This section introduces the basic usage of `DiffSynth-Studio`, including how to enable VRAM management for inference on GPUs with extremely low VRAM, and how to train various base models, LoRAs, ControlNets, and other models. + +* [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md) +* [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) +* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) +* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md) +* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md) + +## Section 2: Model Details + +This section introduces the Diffusion models supported by `DiffSynth-Studio`. Some model pipelines feature special functionalities such as controllable generation and parallel acceleration. + +* [FLUX.1](/docs/en/Model_Details/FLUX.md) +* [Wan](/docs/en/Model_Details/Wan.md) +* [Qwen-Image](/docs/en/Model_Details/Qwen-Image.md) +* [FLUX.2](/docs/en/Model_Details/FLUX2.md) +* [Z-Image](/docs/en/Model_Details/Z-Image.md) + +## Section 3: Training Framework + +This section introduces the design philosophy of the training framework in `DiffSynth-Studio`, helping developers understand the principles of Diffusion model training algorithms. + +* [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md) +* [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) +* [Enabling FP8 Precision in Training](/docs/en/Training/FP8_Precision.md) +* [End-to-End Distillation Accelerated Training](/docs/en/Training/Direct_Distill.md) +* [Two-Stage Split Training](/docs/en/Training/Split_Training.md) +* [Differential LoRA Training](/docs/en/Training/Differential_LoRA.md) + +## Section 4: Model Integration + +This section introduces how to integrate models into `DiffSynth-Studio` to utilize the framework's basic functions, helping developers provide support for new models in this project or perform inference and training of private models. + +* [Integrating Model Architecture](/docs/en/Developer_Guide/Integrating_Your_Model.md) +* [Building a Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) +* [Enabling Fine-Grained VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) +* [Model Training Integration](/docs/en/Developer_Guide/Training_Diffusion_Models.md) + +## Section 5: API Reference + +This section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`, explaining how internal functions are designed and operate. Developers can use these functional modules in other codebase developments if needed. + +* [`diffsynth.core.attention`](/docs/en/API_Reference/core/attention.md): Attention mechanism implementation +* [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md): Data processing operators and general datasets +* [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md): Gradient checkpointing +* [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md): Model download and loading +* [`diffsynth.core.vram`](/docs/en/API_Reference/core/vram.md): VRAM management + +## Section 6: Academic Guide + +This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies. 
+ +* Training models from scratch 【coming soon】 +* Inference improvement techniques 【coming soon】 +* Designing controllable generation models 【coming soon】 +* Creating new training paradigms 【coming soon】 + +## Section 7: Frequently Asked Questions + +This section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub. + +* [Frequently Asked Questions](/docs/en/QA.md) \ No newline at end of file diff --git a/docs/en/Training/Differential_LoRA.md b/docs/en/Training/Differential_LoRA.md new file mode 100644 index 0000000000000000000000000000000000000000..febe5076bae8e04bf69132938d8b663f5406c5dc --- /dev/null +++ b/docs/en/Training/Differential_LoRA.md @@ -0,0 +1,38 @@ +# Differential LoRA Training + +Differential LoRA training is a special form of LoRA training designed to enable models to learn differences between images. + +## Training Approach + +We were unable to identify the original proposer of differential LoRA training, as this technique has been circulating in the open-source community for a long time. + +Assume we have two similar-content images: Image 1 and Image 2. For example, both images contain a car, but Image 1 has fewer details while Image 2 has more details. In differential LoRA training, we perform two-step training: + +* Train LoRA 1 using Image 1 as training data with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md) +* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md) + +In the first training step, since there is only one training image, the LoRA model easily overfits. Therefore, after training, LoRA 1 will cause the model to generate Image 1 without hesitation, regardless of the random seed. In the second training step, the LoRA model overfits again. Thus, after training, with the combined effect of LoRA 1 and LoRA 2, the model will generate Image 2 without hesitation. In short: + +* LoRA 1 = Generate Image 1 +* LoRA 1 + LoRA 2 = Generate Image 2 + +At this point, discarding LoRA 1 and using only LoRA 2, the model will understand the difference between Image 1 and Image 2, making the generated content tend toward "less like Image 1, more like Image 2." + +Single training data can ensure the model overfits to the training data, but lacks stability. To improve stability, we can train with multiple image pairs and average the trained LoRA 2 models to obtain a more stable LoRA. + +Using this training approach, some functionally unique LoRA models can be trained. For example, using ugly and beautiful image pairs to train LoRAs that enhance image aesthetics; using low-detail and high-detail image pairs to train LoRAs that increase image detail. + +## Model Effects + +We have trained several aesthetic enhancement LoRAs using differential LoRA training techniques. You can visit the corresponding model pages to view the generation effects. + +* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1) +* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) + +## Using Differential LoRA Training in the Training Framework + +The first step of training is identical to ordinary LoRA training. 
In the second step's training command, fill in the path of the first step's LoRA model file through the `--preset_lora_path` parameter, and set `--preset_lora_model` to the same parameters as `lora_base_model` to load LoRA 1 into the base model. + +## Framework Design Concept + +In the training framework, the model pointed to by `--preset_lora_path` is loaded in the `switch_pipe_to_training_mode` of `DiffusionTrainingModule`. \ No newline at end of file diff --git a/docs/en/Training/Direct_Distill.md b/docs/en/Training/Direct_Distill.md new file mode 100644 index 0000000000000000000000000000000000000000..4cbeb59dbd26cbf5c29895b2c9cbb1a967d57252 --- /dev/null +++ b/docs/en/Training/Direct_Distill.md @@ -0,0 +1,97 @@ +# End-to-End Distillation Accelerated Training + +## Distillation Accelerated Training + +The inference process of Diffusion models typically requires multi-step iterations, which improves generation quality but also makes the generation process slow. Through distillation accelerated training, the number of steps required to generate clear content can be reduced. The essence of distillation accelerated training technology is to align the generation effects of a small number of steps with those of a large number of steps. + +There are diverse methods for distillation accelerated training, such as: + +* Adversarial training ADD (Adversarial Diffusion Distillation) + * Paper: https://arxiv.org/abs/2311.17042 + * Model: [stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo) +* Progressive training Hyper-SD + * Paper: https://arxiv.org/abs/2404.13686 + * Model: [ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD) + +## Direct Distillation + +At the framework level, supporting these distillation accelerated training schemes is extremely difficult. In the design of the training framework, we need to ensure that the training scheme meets the following conditions: + +* Generality: The training scheme applies to most Diffusion models supported within the framework, rather than only working for a specific model, which is a basic requirement for code framework construction. +* Stability: The training scheme must ensure stable training effects without requiring manual fine-tuning of parameters. Adversarial training in ADD cannot guarantee stability. +* Simplicity: The training scheme does not introduce additional complex modules. According to Occam's Razor principle, complex solutions may introduce potential risks. The Human Feedback Learning in Hyper-SD makes the training process overly complex. + +Therefore, in the training framework of `DiffSynth-Studio`, we designed an end-to-end distillation accelerated training scheme, which we call Direct Distillation. The pseudocode for the training process is as follows: + +``` +seed = xxx +with torch.no_grad(): + image_1 = pipe(prompt, steps=50, seed=seed, cfg=4) +image_2 = pipe(prompt, steps=4, seed=seed, cfg=1) +loss = torch.nn.functional.mse_loss(image_1, image_2) +``` + +Yes, it's a very end-to-end training scheme that produces immediate results with minimal training. 
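To make the training signal concrete, below is a deliberately toy PyTorch sketch of the same loop: as in the pseudocode, both passes run the same model from the same starting noise, the many-step result is computed without gradients and used as the target, and the few-step result is pulled toward it with an MSE loss. A tiny linear "denoiser" and a simplified sampler stand in for the real pipeline, and the differing cfg settings are omitted; the framework's actual implementation is described in the sections that follow.

```python
import torch

class ToyDenoiser(torch.nn.Module):
    # Stand-in for a diffusion model; illustrative only.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.net = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

def sample(model, x, steps: int):
    # Extremely simplified "iterative denoising": a few small update steps.
    for _ in range(steps):
        x = x - 0.1 * model(x)
    return x

model = ToyDenoiser()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x_T = torch.randn(4, 16)                   # shared starting noise (the fixed seed)
with torch.no_grad():
    target = sample(model, x_T, steps=50)  # many-step result, no gradients
prediction = sample(model, x_T, steps=4)   # few-step result to be aligned with it
loss = torch.nn.functional.mse_loss(prediction, target)
loss.backward()
optimizer.step()
```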
+ +## Models Trained with Direct Distillation + +We trained two models based on Qwen-Image using this scheme: + +* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation training +* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation training + +Click on the model links to go to the model pages and view the model effects. + +## Using Distillation Accelerated Training in the Training Framework + +First, you need to generate training data. Please refer to the [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps. + +Taking Qwen-Image as an example, the following code can generate an image: + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +Then, we compile the necessary information into [metadata files](/docs/en/API_Reference/core/data.md#metadata): + +```csv +image,prompt,seed,rand_device,num_inference_steps,cfg_scale +distill_qwen/image.jpg,"精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",0,cpu,4,1 +``` + +This sample dataset can be downloaded directly: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +Then start LoRA distillation accelerated training: + +```shell +bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh +``` + +Please note that in the [training script parameters](/docs/en/Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image. + +## Framework Design Concept + +Compared to [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`. + +## Future Work + +Direct Distillation is a highly general acceleration scheme, but it may not be the best-performing scheme. Therefore, we have not yet published this technology in paper form. We hope to leave this problem to the academic and open-source communities to solve together, and we look forward to developers providing more complete general training schemes. 
\ No newline at end of file diff --git a/docs/en/Training/FP8_Precision.md b/docs/en/Training/FP8_Precision.md new file mode 100644 index 0000000000000000000000000000000000000000..5f23abbc90706a1b4c22a0dcf00aee158cd70b0b --- /dev/null +++ b/docs/en/Training/FP8_Precision.md @@ -0,0 +1,20 @@ +# Enabling FP8 Precision in Training + +Although `DiffSynth-Studio` supports [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes. + +FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](/docs/en/QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision. + +## Enabling FP8 + +In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py). + +Please note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. Models that support FP8 include two types: + +* Parameters are not trainable, such as VAE models +* Gradients do not update their parameters, such as DiT models in LoRA training + +Experimental verification shows that LoRA training with FP8 enabled does not cause significant image quality degradation. However, theoretical errors do exist. If you encounter training results inferior to BF16 precision training when using this feature, please provide feedback through GitHub issues. + +## Training Framework Design Concept + +The training framework completely reuses the inference VRAM management, and only parses VRAM management configurations through `parse_model_configs` in `DiffusionTrainingModule` during training. \ No newline at end of file diff --git a/docs/en/Training/Split_Training.md b/docs/en/Training/Split_Training.md new file mode 100644 index 0000000000000000000000000000000000000000..07068d20287f7ce64796deb0d228eb25919d3aef --- /dev/null +++ b/docs/en/Training/Split_Training.md @@ -0,0 +1,97 @@ +# Two-Stage Split Training + +This document introduces split training, which can automatically divide the training process into two stages, reducing VRAM usage while accelerating training speed. + +(Split training is an experimental feature that has not yet undergone large-scale validation. If you encounter any issues while using it, please submit an issue on GitHub.) 
+ +## Split Training + +In the training process of most models, a large amount of computation occurs in "preprocessing," i.e., "computations unrelated to the denoising model," including VAE encoding, text encoding, etc. When the corresponding model parameters are fixed, the results of these computations are repetitive. For each data sample, the computational results are identical across multiple epochs. Therefore, we provide a "split training" feature that can automatically analyze and split the training process. + +For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](/docs/en/Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation. + +## Computational Graph Splitting Algorithm + +> (We will supplement the detailed specifics of the computational graph splitting algorithm in future document updates) + +## Using Split Training + +Split training already supports [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](/docs/en/Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is: + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters +``` + +After splitting, in the first stage, make the following modifications: + +* Change `--dataset_repeat` to 1 to avoid redundant computation +* Change `--output_path` to the path where the first-stage computation results are saved +* Add the additional parameter `--task "sft:data_process"` +* Remove the DiT model from `--model_id_with_origin_paths` + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:data_process" +``` + +In the second stage, make the following modifications: + +* Change `--dataset_base_path` to the `--output_path` of the first stage +* Remove `--dataset_metadata_path` +* Add the additional parameter `--task "sft:train"` +* Remove the Text Encoder and VAE models from `--model_id_with_origin_paths` + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:train" +``` + +We provide sample training scripts and validation scripts located at `examples/qwen_image/model_training/special/split_training`. + +## Training Framework Design Concept + +The training framework splits the computational units in the `Pipeline` through the `split_pipeline_units` method of `DiffusionTrainingModule`. \ No newline at end of file diff --git a/docs/en/Training/Supervised_Fine_Tuning.md b/docs/en/Training/Supervised_Fine_Tuning.md new file mode 100644 index 0000000000000000000000000000000000000000..fd29c10cf34524ddbf7a7c6ccb1b2ce764384009 --- /dev/null +++ b/docs/en/Training/Supervised_Fine_Tuning.md @@ -0,0 +1,129 @@ +# Standard Supervised Training + +After understanding the [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md). + +Recalling the model training pseudocode from earlier, when we actually write code, the situation becomes extremely complex. Some models require additional guidance conditions and preprocessing, such as ControlNet; some models require cross-computation with the denoising model, such as VACE; some models require Gradient Checkpointing due to excessive VRAM demands, such as Qwen-Image's DiT. + +To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms. + +## Framework Design Concept + +The training module is encapsulated on top of the `Pipeline`, inheriting `DiffusionTrainingModule` from `diffsynth.diffusion.training_module`. 
We need to provide the necessary `__init__` and `forward` methods for the training module. Taking Qwen-Image's LoRA training as an example, we provide a simple script containing only basic training functions in `examples/qwen_image/model_training/special/simple/train.py` to help developers understand the design concept of the training module. + +```python +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__(self, device): + # Initialize models here. + pass + + def forward(self, data): + # Compute loss here. + return loss +``` + +### `__init__` + +In `__init__`, model initialization is required. First load the model, then switch it to training mode. + +```python + def __init__(self, device): + super().__init__() + # Load the pipeline + self.pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + ) + # Switch to training mode + self.switch_pipe_to_training_mode( + self.pipe, + lora_base_model="dit", + lora_target_modules="to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj", + lora_rank=32, + ) +``` + +The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). + +`switch_pipe_to_training_mode` can switch the model to training mode. See `switch_pipe_to_training_mode` for details. + +### `forward` + +In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](/docs/en/Developer_Guide/Building_a_Pipeline.md#model_fn). + +```python + def forward(self, data): + # Preprocess + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": True, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + # Loss + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss +``` + +The preprocessing process is consistent with the inference phase. Developers only need to assume they are using the `Pipeline` for inference and fill in the input parameters. + +The loss function calculation reuses `FlowMatchSFTLoss` from `diffsynth.diffusion.loss`. 
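+
+Conceptually, the flow-matching loss measures how close the model prediction is to the diffusion direction $x_T-x_0$ described in the [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). The snippet below is a minimal, self-contained sketch of that quantity only; it is not the actual `FlowMatchSFTLoss` implementation, whose signature and internal handling of timestep sampling, noising, and loss weighting may differ:
+
+```python
+import torch
+
+def flow_match_loss_sketch(model_pred, latents, noise):
+    # Flow-matching target: the model should predict x_T - x_0,
+    # i.e. the direction from the clean latents to pure noise.
+    target = noise - latents
+    return torch.nn.functional.mse_loss(model_pred.float(), target.float())
+
+# Toy tensors standing in for real VAE latents and a real model prediction.
+latents = torch.randn(1, 16, 64, 64)
+noise = torch.randn_like(latents)
+model_pred = torch.randn_like(latents)
+print(flow_match_loss_sketch(model_pred, latents, noise))
+```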
+ +### Starting Training + +The training framework requires other modules, including: + +* accelerator: Training launcher provided by `accelerate`, see [`accelerate`](https://huggingface.co/docs/accelerate/index) for details +* dataset: Generic dataset, see [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md) for details +* model_logger: Model logger, see `diffsynth.diffusion.logger` for details + +```python +if __name__ == "__main__": + accelerator = accelerate.Accelerator( + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)], + ) + dataset = UnifiedDataset( + base_path="data/example_image_dataset", + metadata_path="data/example_image_dataset/metadata.csv", + repeat=50, + data_file_keys="image", + main_data_operator=UnifiedDataset.default_image_operator( + base_path="data/example_image_dataset", + height=512, + width=512, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule(accelerator.device) + model_logger = ModelLogger( + output_path="models/toy_model", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=1e-5, num_epochs=1, + ) +``` + +Assembling all the above code results in `examples/qwen_image/model_training/special/simple/train.py`. Use the following command to start training: + +``` +accelerate launch examples/qwen_image/model_training/special/simple/train.py +``` \ No newline at end of file diff --git a/docs/en/Training/Understanding_Diffusion_models.md b/docs/en/Training/Understanding_Diffusion_models.md new file mode 100644 index 0000000000000000000000000000000000000000..5c81b6a7defc489cfb18688ebcb90a9bed34f0f4 --- /dev/null +++ b/docs/en/Training/Understanding_Diffusion_models.md @@ -0,0 +1,145 @@ +# Basic Principles of Diffusion Models + +This document introduces the basic principles of Diffusion models to help you understand how the training framework is constructed. To make these complex mathematical theories easier for readers to understand, we have reconstructed the theoretical framework of Diffusion models, abandoning complex stochastic differential equations and presenting them in a more concise and understandable form. + +## Introduction + +Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$. + +(Figure) + +This process is intuitive, but to understand the details, we need to answer several questions: + +* How is the noise content at each step defined? +* How is the iterative denoising computation performed? +* How to train such Diffusion models? +* What is the architecture of modern Diffusion models? +* How does this project encapsulate and implement model training? + +## How is the noise content at each step defined? + +In the theoretical system of Diffusion models, the noise content is determined by a series of parameters $\sigma_T$, $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma_0$. 
Where:
+
+* $\sigma_T=1$, corresponding to $x_T$ as pure Gaussian noise
+* $\sigma_T>\sigma_{T-1}>\sigma_{T-2}>\cdots>\sigma_0$, the noise content gradually decreases during iteration
+* $\sigma_0=0$, corresponding to $x_0$ as a data sample without any noise
+
+As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma_1$, they are not fixed and only need to satisfy the decreasing condition.
+
+At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
+
+(Figure)
+
+## How is the iterative denoising computation performed?
+
+Before understanding the iterative denoising computation, we need to clarify what the input and output of the denoising model are. We abstract the model as a symbol $\hat \epsilon$, whose input typically consists of three parts:
+
+* Time step $t$, the model needs to understand which stage of the denoising process it is currently in
+* Noisy data sample $x_t$, the model needs to understand what data to denoise
+* Guidance condition $c$, the model needs to understand what kind of data sample to generate through denoising
+
+Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
+
+(Figure)
+
+The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
+
+Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
+
+$$
+\begin{aligned}
+x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\
+&\approx x_t + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
+&=(1-\sigma_t)x_0+\sigma_t x_T + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
+&=(1-\sigma_{t-1})x_0+\sigma_{t-1}x_T
+\end{aligned}
+$$
+
+Perfect! This exactly matches the noise content definition at time step $t-1$.
+
+> (This part might be a bit difficult to understand. Don't worry; it's recommended to skip this part on first reading without affecting the rest of the document.)
+>
+> After completing this somewhat complex formula derivation, let's consider a question: why should the model's output approximately equal $x_T-x_0$? Can it be set to other values?
+>
+> Actually, Diffusion models rely on two definitions to form a complete theory. From the above formulas, we can extract these two definitions and derive the iterative formula:
+>
+> * Data definition: $x_t=(1-\sigma_t)x_0+\sigma_t x_T$
+> * Model definition: $\hat \epsilon(x_t,c,t)=x_T-x_0$
+> * Derived iterative formula: $x_{t-1}=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)$
+>
+> These three mathematical formulas are complete. For example, in the previous derivation, substituting the data definition and model definition into the iterative formula yields $x_{t-1}$ that matches the data definition.
+>
+> These are two definitions built on Flow Matching theory, but Diffusion models can also be implemented with other definitions. For example, early models based on DDPM (Denoising Diffusion Probabilistic Models) have their two definitions and derived iterative formula as:
+>
+> * Data definition: $x_t=\sqrt{\alpha_t}x_0+\sqrt{1-\alpha_t}x_T$
+> * Model definition: $\hat \epsilon(x_t,c,t)=x_T$
+> * Derived iterative formula: $x_{t-1}=\sqrt{\alpha_{t-1}}\left(\frac{x_t-\sqrt{1-\alpha_t}\hat \epsilon(x_t,c,t)}{\sqrt{\alpha_t}}\right)+\sqrt{1-\alpha_{t-1}}\hat \epsilon(x_t,c,t)$
+>
+> More generally, we describe the derivation process of the iterative formula using matrices. For any data definition and model definition:
+>
+> * Data definition: $x_t=C_t(x_0,x_T)^T$
+> * Model definition: $\hat \epsilon(x_t,c,t)=C_t^{[\epsilon]}(x_0,x_T)^T$
+> * Derived iterative formula: $x_{t-1}=C_{t-1}(C_t,C_t^{[\epsilon]})^{-T}(x_t,\hat \epsilon(x_t,c,t))^T$
+>
+> Where $C_t$ and $C_t^{[\epsilon]}$ are $1\times 2$ coefficient matrices. It's not difficult to see that when constructing the two definitions, the matrix $(C_t,C_t^{[\epsilon]})^T$ must be invertible.
+>
+> Although Flow Matching and DDPM have been widely verified by numerous pre-trained models, this doesn't mean they are optimal solutions. We encourage developers to design new Diffusion model theories for better training results.
+
+## How to train such Diffusion models?
+
+After understanding the iterative denoising process, we next consider how to train such Diffusion models.
+
+The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
+
+(Figure)
+
+The following is pseudocode for the training process:
+
+> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
+>
+> Randomly sample time step $t\in(0,T]$
+>
+> Randomly sample Gaussian noise $x_T\sim \mathcal N(0,I)$
+>
+> $x_t=(1-\sigma_t)x_0+\sigma_t x_T$
+>
+> Compute the model prediction $\hat \epsilon(x_t,c,t)$
+>
+> Loss function $\mathcal L=||\hat \epsilon(x_t,c,t)-(x_T-x_0)||_2^2$
+>
+> Backpropagate gradients and update model parameters
+
+## What is the architecture of modern Diffusion models?
+
+From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
+
+(Figure)
+
+### Data Encoder-Decoder
+
+In the previous text, we consistently referred to $x_0$ as a "data sample" rather than an image or video because modern Diffusion models typically don't process images or videos directly. Instead, they use an Encoder-Decoder architecture model, usually a VAE (Variational Auto-Encoder), to encode images or videos into Embedding tensors, obtaining $x_0$.
+
+After data is encoded by the encoder and then decoded by the decoder, the reconstructed content is approximately consistent with the original, with minor errors. So why operate on the encoded Embedding tensor rather than directly on images or videos? The main reasons are twofold:
+
+* Encoding also compresses the data, reducing the computational load of subsequent processing.
+* The encoded data distribution is closer to a Gaussian distribution, making it easier for the denoising model to model the data.
+
+During generation, the encoder part doesn't participate in computation.
After iteration completes, the decoder part decodes $x_0$ to obtain clear images or videos. During training, the decoder part doesn't participate in computation; only the encoder is used to compute $x_0$. + +### Guidance Condition Encoder + +User-input guidance conditions $c$ can be complex and diverse, requiring specialized encoder models to process them into Embedding tensors. According to the type of guidance condition, we classify guidance condition encoders into the following categories: + +* Text type, such as CLIP, Qwen-VL +* Image type, such as ControlNet, IP-Adapter +* Video type, such as VAE + +> The model $\hat \epsilon$ mentioned in the previous text refers to the entirety of all guidance condition encoders and the denoising model. We list guidance condition encoders separately because these models are typically frozen during Diffusion training, and their output values are independent of time step $t$, allowing guidance condition encoder computations to be performed offline. + +### Denoising Model + +The denoising model is the true essence of Diffusion models, with diverse model structures such as UNet and DiT. Model developers can freely innovate on these structures. + +## How does this project encapsulate and implement model training? + +Please read the next document: [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) \ No newline at end of file diff --git a/docs/wan_video_instance_control_design.md b/docs/wan_video_instance_control_design.md new file mode 100644 index 0000000000000000000000000000000000000000..a32d131d86002c8e95adc16d2d92170452500482 --- /dev/null +++ b/docs/wan_video_instance_control_design.md @@ -0,0 +1,251 @@ +# Wan Video Instance Control:模型设计说明(bbox + per-frame state weights) + +本文档描述当前 DiffSynth-Studio 的 **Wan Video Statemachine + Instance Control** 设计:仅使用 + +- `instance_ids`(区分同类不同个体) +- `instance_class_text`(每个实例的 tag/class 文本) +- `instance_state_texts`(每个实例的 **固定** state 文本集合) +- `instance_state_weights`(**逐帧** state 权重,允许软融合) +- `instance_bboxes`(**逐帧** 2D bbox,xyxy 像素坐标) + +来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(`class_id/state_id/mask/state_a/b/progress` 等)不再使用。 + +--- + +## 1. 
入口 API 与张量约定 + +入口在 `diffsynth/pipelines/wan_video_statemachine.py` 的 `WanVideoPipeline.__call__`。 + +### 1.1 必需字段(启用 instance control 时) + +当你传入以下任意一个字段不为 `None` 时,pipeline 视为启用 instance control,并要求 **全部提供**: + +- `instance_ids`: `Tensor`,形状 `(B, N)` 或 `(N,)`,dtype `long` + - `N`:实例数(objects) +- `instance_class_text`: `List[str]`(长度 `N`)或 `str`(单实例) + - 每个实例一个 class/tag,例如 `"egg"`, `"dog"`, `"person"`… +- `instance_state_texts`: `List[List[str]]`(形状 `N × S`)或 `List[str]`(单实例) + - 每个实例有一个 **固定大小** 的 state 候选集合(`S` 个 state 文本) + - 例如单实例:`["raw", "cooked"]`;多实例:`[["idle","run"], ["open","close"]]` + - 约束:所有实例的 `S` 必须相同(当前实现强制)。 +- `instance_state_weights`: `Tensor`/list,形状 `(B, N, F, S)` 或 `(N, F, S)`,dtype `float` + - `F`:逐帧权重的时间长度(推荐等于输入视频帧数 `num_frames`,但允许不同,后续会映射/下采样到 patch-time) + - `S`:state 数量,必须等于 `instance_state_texts` 的 state 数 + - 语义:对每个 `(b,n,f)`,给出 `S` 个 state 的权重(可 one-hot,也可软融合) +- `instance_bboxes`: `Tensor`,形状 `(B, N, F, 4)` 或 `(N, F, 4)`,dtype `float` + - bbox 是 `xyxy`,单位为像素坐标,坐标系必须与推理时的 `height/width` 对齐 + - 约束:`instance_bboxes.shape[2]` 必须等于 `instance_state_weights.shape[2]`(同一个 `F`) + +### 1.2 推荐的常见配置 + +- **单实例**(N=1)+ 两状态(S=2): + - `instance_class_text="egg"` + - `instance_state_texts=["raw","cooked"]` + - `instance_state_weights.shape=(1,1,F,2)` + - `instance_bboxes.shape=(1,1,F,4)` +- **多实例**(N>1): + - `instance_class_text` 长度必须与 `N` 相同 + - `instance_state_texts` 外层长度必须与 `N` 相同 + +--- + +## 2. Pipeline 数据流(从输入到 model_fn) + +对应代码: + +- `diffsynth/pipelines/wan_video_statemachine.py` + - `WanVideoPipeline.__call__` + - `WanVideoUnit_InstanceStateTextEmbedder` + - `model_fn_wan_video` + +### 2.1 参数归一化与校验(__call__) + +在 `__call__` 中会把输入转为 Tensor,并补齐 batch 维: + +- `instance_ids`:若输入为 `(N,)` 会补成 `(1,N)` +- `instance_bboxes`:若输入为 `(N,F,4)` 会补成 `(1,N,F,4)` +- `instance_state_weights`:若输入为 `(N,F,S)` 会补成 `(1,N,F,S)` + +启用 instance control 时会做关键校验: + +- 5 个输入必须同时存在:`ids/class_text/state_texts/state_weights/bboxes` +- `state_weights` 与 `bboxes` 的 `F` 必须一致 + +### 2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder) + +该 unit 负责把 `(class_text, state_texts)` 变成可供 DiT 使用的 state phrase embedding: + +1. 先构造短语: + - 对每个实例 `n`,对每个 state `s`: + - phrase = `" is "` +2. 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding: + - 输出 `instance_state_text_embeds_multi`,形状 `(1, N, S, text_dim)` + +注意: + +- 这里不使用 `instance_state_weights` 做融合;融合在 DiT 内根据逐帧权重完成。 +- unit 只产出 `instance_state_text_embeds_multi`,并且 pipeline 在 unit 之后会把 `instance_class_text/instance_state_texts` 从 `inputs_shared` 中移除,确保下游 model_fn 只接收张量(最小化接口)。 + +--- + +## 3. DiT 内部设计(instance tokens + bbox mask-guided attention) + +对应代码: + +- `diffsynth/models/wan_video_dit_instance.py` + - `InstanceFeatureExtractor` + - `MaskGuidedCrossAttention` + - `DiTBlock.forward(..., instance_tokens, instance_masks)` + +### 3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens” + +核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。 + +#### 输入 + +- `state_text_embeds_multi`: `(B, N, S, text_dim)` + 每个 state 对应短语 `" is "` 的 pooled embedding +- `state_weights`: `(B, N, F, S)` + 每帧对 `S` 个 state 的权重 +- `instance_ids`: `(B, N)` + 用于区分同类个体 +- `num_time_patches = f` + DiT patchify 后的时间 patch 数(由 `patch_embedding` 决定) + +#### 步骤 + +1. **文本投影到 hidden_dim** + - `sem_multi = text_proj(state_text_embeds_multi)` → `(B, N, S, H)` +2. **权重截断与时间下采样** + - `weights = clamp(state_weights, min=0)` + - 若 `F != f`:把 `(B,N,F,S)` 平均池化到 `(B,N,f,S)` + - 映射规则:`pt = floor(t * f / F)` +3. 
**按权重对 state 语义做逐时间融合** + - `sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)` + - 得到 `(B, N, f, H)` +4. **融合 instance_id embedding** + - `i_feat = Emb(instance_ids)` → `(B, N, H)`,并广播到时间维 `(B, N, f, H)` + - 拼接并通过 fusion MLP: + - `token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )` + - 输出 `inst_tokens`:`(B, f, N, D)`(注意转置后时间维在前) + +### 3.2 bbox → patch mask(每个 patch 是否被某实例覆盖) + +`WanModel.process_masks` 将 `instance_bboxes` 投影到 patch token 网格,返回 `inst_mask_flat`: + +- 输入 bbox:`(B, N, F, 4)`,`xyxy` 像素坐标 +- patch 网格:`(f_p, h_p, w_p)` +- 输出 mask:`(B, N, L)`,其中 `L = f_p * h_p * w_p` + +关键映射规则: + +- 空间缩放: + - `px = x * (w_p / W_img)` + - `py = y * (h_p / H_img)` +- 时间映射: + - `pt = floor(t * f_p / F_bbox)` + +最终每个 `(b,n,pt)` 上把 bbox 覆盖到的 `(py0:py1, px0:px1)` patch 置 1。 + +### 3.3 MaskGuidedCrossAttention(log-mask trick) + +每个 DiT block 都包含一个 instance cross attention: + +- Q:来自 patch tokens `x`(形状 `(B, L, D)`) +- K/V:来自 instance tokens(按时间对齐后使用) +- Mask:`(B, N, L)` + +attention logits 里加入 `log(mask)` 作为 bias: + +- `sim = (q · k) / sqrt(d)` +- `sim = sim + log(mask.clamp(min=1e-6))` + +这样 mask=0 的位置会得到接近 `-inf` 的 bias,从而 softmax 后强制为 0,实现 **只让每个 patch 关注覆盖它的实例**。 + +### 3.4 时间对齐方式(per-time tokens vs per-token tokens) + +`MaskGuidedCrossAttention` 支持三种形状: + +- `(B, N, D)`:整段序列共享同一组 instance tokens(当前不用) +- `(B, T, N, D)`:按 patch-time 切分(默认路径) + - 假设序列按时间展开:`L = T * (h*w)`,按时间分段计算 attention +- `(B, L, N, D)`:按 token 位置提供 instance tokens(用于 Unified Sequence Parallel) + +在 `model_fn_wan_video` 开启 USP 时会将 `inst_tokens (B,T,N,D)` 转换成当前 rank 的 chunk 对应的 `(B, chunk_len, N, D)`: + +- 先计算每个 token 在全局序列里的位置 `global_pos` +- `time_index = global_pos // (h*w)` +- `inst_tokens_chunk = inst_tokens[:, time_index]` + +并对 padding 部分置 0,避免污染。 + +### 3.5 Reference latents 的处理 + +当 pipeline 使用 `reference_latents` 拼到序列前面时: + +- patch token 序列会多出 1 个时间片(`f += 1`) +- `inst_mask_flat` 会在序列前补 0(reference 部分不属于任何 instance) +- `inst_tokens` 也会在时间维前补 0(reference 时间片不注入 instance 语义) + +--- + +## 4. 重要限制与注意事项 + +1. **必须给每个实例提供相同数量的 state 文本(S 必须一致)** +2. **`instance_state_weights` 与 `instance_bboxes` 的时间长度 `F` 必须一致** +3. **bbox 的像素坐标必须与推理时的 `height/width` 对齐** + - 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox +4. **sliding window 不支持 instance control** + - `model_fn_wan_video` 在 `sliding_window_size/stride` 与 instance 输入同时存在时直接报错 + +--- + +## 5. 最小可用示例(伪代码) + +```python +F = num_frames +N = 1 +S = 3 + +instance_ids = torch.tensor([[1]]) # (1,1) +instance_class_text = ["egg"] # len=1 +instance_state_texts = [["raw", "half", "cooked"]] # (N,S) + +# 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked +w = torch.zeros((1,1,F,S), dtype=torch.float32) +t = torch.linspace(0, 1, F) +w[0,0,:,0] = (1 - t) # raw +w[0,0,:,2] = t # cooked +w[0,0,:,1] = 0.0 # half (可选) + +# bbox (1,1,F,4):每帧一个 bbox,xyxy +b = torch.zeros((1,1,F,4), dtype=torch.float32) +b[0,0,:,0] = 100; b[0,0,:,1] = 120 +b[0,0,:,2] = 260; b[0,0,:,3] = 320 + +video = pipe( + prompt="...", + height=H, width=W, num_frames=F, + instance_ids=instance_ids, + instance_class_text=instance_class_text, + instance_state_texts=instance_state_texts, + instance_state_weights=w, + instance_bboxes=b, +) +``` + +--- + +## 6. 代码入口索引 + +- Pipeline API / 文本编码: + - `diffsynth/pipelines/wan_video_statemachine.py` + - `WanVideoPipeline.__call__` + - `WanVideoUnit_InstanceStateTextEmbedder` + - `model_fn_wan_video` +- Instance-aware DiT: + - `diffsynth/models/wan_video_dit_instance.py` + - `InstanceFeatureExtractor` + - `MaskGuidedCrossAttention` + - `WanModel.forward(... 
instance_*)` + diff --git a/docs/zh/API_Reference/core/attention.md b/docs/zh/API_Reference/core/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..c30e180fa92b5c7ed3aa44fe37dfeceda551451b --- /dev/null +++ b/docs/zh/API_Reference/core/attention.md @@ -0,0 +1,79 @@ +# `diffsynth.core.attention`: 注意力机制实现 + +`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。 + +## 注意力机制 + +注意力机制是在论文[《Attention Is All You Need》](https://arxiv.org/abs/1706.03762)中提出的模型结构,在原论文中,注意力机制按照如下公式实现: + +$$ +\text{Attention}(Q, K, V) = \text{Softmax}\left( + \frac{QK^T}{\sqrt{d_k}} +\right) +V. +$$ + +在 `PyTorch` 中,可以用如下代码实现: +```python +import torch + +def attention(query, key, value): + scale_factor = 1 / query.size(-1)**0.5 + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + +query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +output_1 = attention(query, key, value) +``` + +其中 `query`、`key`、`value` 的维度是 $(b, n, s, d)$: +* $b$:Batch size +* $n$: Attention head 的数量 +* $s$: 序列长度 +* $d$: 每个 Attention head 的维数 + +这部分计算是不包含任何可训练参数的,现代 transformer 架构的模型会在进行这一计算前后经过 Linear 层,本文讨论的“注意力机制”不包含这些计算,仅包含以上代码的计算。 + +## 更高效的实现 + +注意到,注意力机制中 Attention Score(公式中的 $\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$,代码中的 `attn_weight`)的维度为 $(b, n, s, s)$,其中序列长度 $s$ 通常非常大,这导致计算的时间和空间复杂度达到平方级。以图像生成模型为例,图像的宽度和高度每增加到 2 倍,序列长度增加到 4 倍,计算量和显存需求增加到 16 倍。为了避免高昂的计算成本,需采用更高效的注意力机制实现,包括 +* Flash Attention 3:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2407.08608) +* Flash Attention 2:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2307.08691) +* Sage Attention:[GitHub](https://github.com/thu-ml/SageAttention)、[论文](https://arxiv.org/abs/2505.11594) +* xFormers:[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) +* PyTorch:[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) + +如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。 + +```python +from diffsynth.core.attention import attention_forward +import torch + +def attention(query, key, value): + scale_factor = 1 / query.size(-1)**0.5 + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + +query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda") +output_1 = attention(query, key, value) +output_2 = attention_forward(query, key, value) +print((output_1 - output_2).abs().mean()) +``` + +请注意,加速的同时会引入误差,但在大多数情况下误差是可以忽略不计的。 + +## 开发者导引 + +在为 `DiffSynth-Studio` 接入新模型时,开发者可自行决定是否调用 `diffsynth.core.attention` 中的 `attention_forward`,但我们期望模型能够尽可能优先调用这一模块,以便让新的注意力机制实现能够在这些模型上直接生效。 + +## 最佳实践 + +**在大多数情况下,我们建议直接使用 `PyTorch` 原生的实现,无需安装任何额外的包。** 
虽然其他注意力机制实现可以加速,但加速效果是较为有限的,在少数情况下会出现兼容性和精度不足的问题。 + +此外,高效的注意力机制实现会逐步集成到 `PyTorch` 中,`PyTorch` 的 `2.9.0` 版本中的 `scaled_dot_product_attention` 已经集成了 Flash Attention 2。我们仍在 `DiffSynth-Studio` 提供这一接口,是为了让一些激进的加速方案能够快速走向应用,尽管它们在稳定性上还需要时间验证。 diff --git a/docs/zh/API_Reference/core/data.md b/docs/zh/API_Reference/core/data.md new file mode 100644 index 0000000000000000000000000000000000000000..60500a736ef2ba0e57bbc2284c6b5b33f9b928ab --- /dev/null +++ b/docs/zh/API_Reference/core/data.md @@ -0,0 +1,151 @@ +# `diffsynth.core.data`: 数据处理算子与通用数据集 + +## 数据处理算子 + +### 可用数据处理算子 + +`diffsynth.core.data` 提供了一系列数据处理算子,用于进行数据处理,包括: + +* 数据格式转换算子 + * `ToInt`: 转换为 int 格式 + * `ToFloat`: 转换为 float 格式 + * `ToStr`: 转换为 str 格式 + * `ToList`: 转换为列表格式,以列表包裹此数据 + * `ToAbsolutePath`: 将相对路径转换为绝对路径 +* 文件加载算子 + * `LoadImage`: 读取图片文件 + * `LoadVideo`: 读取视频文件 + * `LoadAudio`: 读取音频文件 + * `LoadGIF`: 读取 GIF 文件 + * `LoadTorchPickle`: 读取由 [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) 保存的二进制文件【该算子可能导致二进制文件中的代码注入攻击,请谨慎使用!】 +* 媒体文件处理算子 + * `ImageCropAndResize`: 对图像进行裁剪和拉伸 +* Meta 算子 + * `SequencialProcess`: 将序列中的每个数据路由到一个算子 + * `RouteByExtensionName`: 按照文件扩展名路由到特定算子 + * `RouteByType`: 按照数据类型路由到特定算子 + +### 算子使用 + +数据算子之间以 `>>` 符号连接形成数据处理流水线,例如: + +```python +from diffsynth.core.data.operators import * + +data = "image.jpg" +data_pipeline = ToAbsolutePath(base_path="/data") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512) +data = data_pipeline(data) +``` + +在经过每个算子后,数据被依次处理 + +* `ToAbsolutePath(base_path="/data")`: `"/data/image.jpg"` +* `LoadImage()`: `` +* `ImageCropAndResize(max_pixels=512*512)`: `` + +我们可以组合出功能完备的数据流水线,例如通用数据集的默认视频数据算子为 + +```python +RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), +]) +``` + +它包含如下逻辑: + +* 如果是 `str` 类型的数据 + * 如果是 `"jpg", "jpeg", "png", "webp"` 类型文件 + * 加载这张图片 + * 裁剪并缩放到特定分辨率 + * 打包进列表,视为单帧视频 + * 如果是 `"gif"` 类型文件 + * 加载 gif 文件内容 + * 将每一帧裁剪和缩放到特定分辨率 + * 如果是 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"` 类型文件 + * 加载 gif 文件内容 + * 将每一帧裁剪和缩放到特定分辨率 +* 如果不是 `str` 类型的数据,报错 + +## 通用数据集 + +`diffsynth.core.data` 提供了统一的数据集实现,数据集需输入以下参数: + +* `base_path`: 根目录,若数据集中包含图片文件的相对路径,则需填入此字段用于加载这些路径指向的文件 +* `metadata_path`: 元数据目录,记录所有元数据的文件路径,支持 `csv`、`json`、`jsonl` 格式 +* `repeat`: 数据重复次数,默认为 1,该参数影响一个 epoch 的训练步数 +* `data_file_keys`: 需进行加载的数据字段名,例如 `(image, edit_image)` +* `main_data_operator`: 主加载算子,需通过数据处理算子组装好数据处理流水线 +* `special_operator_map`: 特殊算子映射,对需要特殊处理的字段构建的算子映射 + +### 元数据 + +数据集的 `metadata_path` 指向元数据文件,支持 `csv`、`json`、`jsonl` 格式,以下提供了样例 + +* `csv` 格式:可读性高、不支持列表数据、内存占用小 + +```csv +image,prompt +image_1.jpg,"a dog" +image_2.jpg,"a cat" +``` + +* `json` 格式:可读性高、支持列表数据、内存占用大 + +```json +[ + { + "image": "image_1.jpg", + "prompt": "a dog" + }, + { + "image": "image_2.jpg", + "prompt": "a cat" + } +] +``` + +* `jsonl` 格式:可读性低、支持列表数据、内存占用小 + +```json +{"image": "image_1.jpg", "prompt": "a dog"} +{"image": "image_2.jpg", 
"prompt": "a cat"} +``` + +如何选择最佳的元数据格式? + +* 如果数据量大,达到千万级的数据量,由于 `json` 文件解析时需要额外内存,此时不可用,请使用 `csv` 或 `jsonl` 格式 +* 如果数据集中包含列表数据,例如编辑模型需输入多张图,由于 `csv` 格式无法存储列表格式数据,此时不可用,请使用 `json` 或 `jsonl` 格式 + +### 数据加载逻辑 + +在没有进行额外设置时,数据集默认输出元数据集中的数据,图片和视频文件的路径会以字符串的格式输出,若要加载这些文件,则需要设置 `data_file_keys`、`main_data_operator`、`special_operator_map`。 + +在数据处理流程中,按如下逻辑进行处理: +* 如果字段位于 `special_operator_map`,则调用 `special_operator_map` 中的对应算子进行处理 +* 如果字段不位于 `special_operator_map` + * 如果字段位于 `data_file_keys`,则调用 `main_data_operator` 算子进行处理 + * 如果字段不位于 `data_file_keys`,则不进行处理 + +`special_operator_map` 可用于实现特殊的数据处理,例如模型 [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) 中输入的人物面部视频 `animate_face_video` 是以固定分辨率处理的,与输出视频不一致,因此这一字段由专门的算子处理: + +```python +special_operator_map={ + "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)), +} +``` + +### 其他注意事项 + +当数据量过少时,可适当增加 `repeat`,延长单个 epoch 的训练时间,避免频繁保存模型产生较多耗时。 + +当数据量 * `repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。 diff --git a/docs/zh/API_Reference/core/gradient.md b/docs/zh/API_Reference/core/gradient.md new file mode 100644 index 0000000000000000000000000000000000000000..f92f6e8c855b832a473273d1bcf02eb606f5925d --- /dev/null +++ b/docs/zh/API_Reference/core/gradient.md @@ -0,0 +1,69 @@ +# `diffsynth.core.gradient`: 梯度检查点及其 Offload + +`diffsynth.core.gradient` 中提供了封装好的梯度检查点及其 Offload 版本,用于模型训练。 + +## 梯度检查点 + +梯度检查点是用于减少训练时显存占用的技术。我们提供一个例子来帮助你理解这一技术,以下是一个简单的模型结构 + +```python +import torch + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = model(x) +``` + +在这个模型结构中,输入的参数 $x$ 经过 Sigmoid 激活函数得到输出值 $y=\frac{1}{1+e^{-x}}$。 + +在训练过程中,假定我们的损失函数值为 $\mathcal L$,在梯度反响传播时,我们得到 $\frac{\partial \mathcal L}{\partial y}$,此时我们需计算 $\frac{\partial \mathcal L}{\partial x}$,不难发现 $\frac{\partial y}{\partial x}=y(1-y)$,进而有 $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$。如果在模型前向传播时保存 $y$ 的数值,并在梯度反向传播时直接计算 $y(1-y)$,这将避免复杂的 exp 计算,加快计算速度,但这会导致我们需要额外的显存来存储中间变量 $y$。 + +不启用梯度检查点时,训练框架会默认存储所有辅助梯度计算的中间变量,从而达到最佳的计算速度。开启梯度检查点时,中间变量则不会存储,但输入参数 $x$ 仍会存储,减少显存占用,在梯度反向传播时需重新计算这些变量,减慢计算速度。 + +## 启用梯度检查点及其 Offload + +`diffsynth.core.gradient` 中的 `gradient_checkpoint_forward` 实现了梯度检查点及其 Offload,可参考以下代码调用: + +```python +import torch +from diffsynth.core.gradient import gradient_checkpoint_forward + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = gradient_checkpoint_forward( + model, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + x=x, +) +``` + +* 当 `use_gradient_checkpointing=False` 且 `use_gradient_checkpointing_offload=False` 时,计算过程与原始计算完全相同,不影响模型的推理和训练,你可以直接将其集成到代码中。 +* 当 `use_gradient_checkpointing=True` 且 `use_gradient_checkpointing_offload=False` 时,启用梯度检查点。 +* 当 `use_gradient_checkpointing_offload=True` 时,启用梯度检查点,所有梯度检查点的输入参数存储在内存中,进一步降低显存占用和减慢计算速度。 + +## 最佳实践 + +> Q: 应当在何处启用梯度检查点? 
+> +> A: 对整个模型启用梯度检查点时,计算效率和显存占用并不是最优的,我们需要设置细粒度的梯度检查点,但同时不希望为框架增加过多繁杂的代码。因此我们建议在 `Pipeline` 的 `model_fn` 中实现,例如 `diffsynth/pipelines/qwen_image.py` 中的 `model_fn_qwen_image`,在 Block 层级启用梯度检查点,不需要修改模型结构的任何代码。 + +> Q: 什么情况下需要启用梯度检查点? +> +> A: 随着模型参数量越来越大,梯度检查点已成为必要的训练技术,梯度检查点通常是需要启用的。梯度检查点的 Offload 则仅需在激活值占用显存过大的模型(例如视频生成模型)中启用。 diff --git a/docs/zh/API_Reference/core/loader.md b/docs/zh/API_Reference/core/loader.md new file mode 100644 index 0000000000000000000000000000000000000000..ad2d245dfc9da3febd95701dd9b7b03b5e46ab86 --- /dev/null +++ b/docs/zh/API_Reference/core/loader.md @@ -0,0 +1,141 @@ +# `diffsynth.core.loader`: 模型下载与加载 + +本文档介绍 `diffsynth.core.loader` 中模型下载与加载相关的功能。 + +## ModelConfig + +`diffsynth.core.loader` 中的 `ModelConfig` 用于标注模型下载来源、本地路径、显存管理配置等信息。 + +### 从远程下载并加载模型 + +以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例,在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 + +默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig( + model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", + origin_file_pattern="model.safetensors", +) +# Download models +config.download_if_necessary() +print(config.path) +``` + +调用 `download_if_necessary` 后,模型会自动下载,并将路径返回到 `config.path` 中。 + +### 从本地路径加载模型 + +如果从本地路径加载模型,则需要填入 `path`: + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig(path="models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors") +``` + +如果模型包含多个分片文件,以列表的形式输入即可: + +```python +from diffsynth.core import ModelConfig + +config = ModelConfig(path=[ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +]) +``` + +### 显存管理配置 + +`ModelConfig` 也包含了显存管理配置信息,详见[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)。 + +## 模型文件加载 + +`diffsynth.core.loader` 提供了统一的 `load_state_dict`,用于加载模型文件中的 state dict。 + +加载单个模型文件: + +```python +from diffsynth.core import load_state_dict + +state_dict = load_state_dict("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors") +``` + +加载多个模型文件(合并为一个 state dict): + +```python +from diffsynth.core import load_state_dict + +state_dict = load_state_dict([ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +]) +``` + +## 模型哈希 + +模型哈希是用于判断模型类型的,哈希值可通过 `hash_model_file` 获取: + +```python +from diffsynth.core import hash_model_file + +print(hash_model_file("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")) +``` + +也可计算多个模型文件的哈希值,等价于合并 state dict 后计算模型哈希值: + +```python +from diffsynth.core import hash_model_file + +print(hash_model_file([ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + 
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +])) +``` + +模型哈希值只与模型文件中 state dict 的 keys 和 tensor shape 有关,与模型参数的数值、文件保存时间等信息无关。在计算 `.safetensors` 格式文件的模型哈希值时,`hash_model_file` 是几乎瞬间完成的,无需读取模型的参数;但在计算 `.bin`、`.pth`、`.ckpt` 等二进制文件的模型哈希值时,则需要读取全部模型参数,因此**我们不建议开发者继续使用这些格式的文件。** + +通过[编写模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`,开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。 + +## 模型加载 + +`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口,它会调用 [skip_model_initialization](/docs/zh/API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload),则调用 [DiskMap](/docs/zh/API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载;如果没有启用 Disk Offload,则调用 [load_state_dict](#模型文件加载) 加载模型参数。如果需要的话,还会调用 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。 + +以下是一个启用了 Disk Offload 的使用案例: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] + +model = load_model( + QwenImageDiT, + model_path, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config={ + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +``` diff --git a/docs/zh/API_Reference/core/vram.md b/docs/zh/API_Reference/core/vram.md new file mode 100644 index 0000000000000000000000000000000000000000..f79b9dac4e1c50722ce5240c4e9432f0bceb17ac --- /dev/null +++ b/docs/zh/API_Reference/core/vram.md @@ -0,0 +1,66 @@ +# `diffsynth.core.vram`: 显存管理 + +本文档介绍 `diffsynth.core.vram` 中的显存管理底层功能,如果你希望将这些功能用于其他的代码库中,可参考本文档。 + +## 跳过模型参数初始化 + +在 `PyTorch` 中加载模型时,模型的参数默认会占用显存或内存并进行参数初始化,而这些参数会在加载预训练权重后被覆盖掉,这导致了冗余的计算。`PyTorch` 中没有提供接口来跳过这些冗余的计算,我们在 `diffsynth.core.vram` 中提供了 `skip_model_initialization` 用于跳过模型参数初始化。 + +默认的模型加载方式: + +```python +from diffsynth.core import load_state_dict +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet + +model = QwenImageBlockWiseControlNet() # Slow +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") +model.load_state_dict(state_dict, assign=True) +``` + +跳过参数初始化的模型加载方式: + +```python +from diffsynth.core import load_state_dict, skip_model_initialization +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet + +with skip_model_initialization(): + model = QwenImageBlockWiseControlNet() # Fast +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") +model.load_state_dict(state_dict, assign=True) +``` + +在 `DiffSynth-Studio` 中,所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。 + +## State Dict 硬盘映射 + 
+对于某个模型的预训练权重文件,如果我们只需要读取其中的一组参数,而非全部参数,State Dict 硬盘映射可以加速这一过程。我们在 `diffsynth.core.vram` 中提供了 `DiskMap` 用于按需加载模型参数。 + +默认的权重加载方式: + +```python +from diffsynth.core import load_state_dict + +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = load_state_dict(path, device="cpu") # Slow +print(state_dict["img_in.weight"]) +``` + +使用 `DiskMap` 只加载特定参数: + +```python +from diffsynth.core import DiskMap + +path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors" +state_dict = DiskMap(path, device="cpu") # Fast +print(state_dict["img_in.weight"]) +``` + +`DiskMap` 是 `DiffSynth-Studio` 中 Disk Offload 的基本组件,开发者在[配置细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。 + +`DiskMap` 是利用 `.safetensors` 文件的特性实现的功能,因此在使用 `.bin`、`.pth`、`.ckpt` 等二进制文件时,模型的参数是全量加载的,这也导致 Disk Offload 不支持这些格式的文件。**我们不建议开发者继续使用这些格式的文件。** + +## 显存管理可替换模块 + +在启用 `DiffSynth-Studio` 的显存管理后,模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块,其使用方式详见[细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。 diff --git a/docs/zh/Developer_Guide/Building_a_Pipeline.md b/docs/zh/Developer_Guide/Building_a_Pipeline.md new file mode 100644 index 0000000000000000000000000000000000000000..cac5b62a289c8aaea0f93b70e91c268a5022439e --- /dev/null +++ b/docs/zh/Developer_Guide/Building_a_Pipeline.md @@ -0,0 +1,250 @@ +# 接入 Pipeline + +在[将 Pipeline 所需的模型接入](/docs/zh/Developer_Guide/Integrating_Your_Model.md)之后,还需构建 `Pipeline` 用于模型推理,本文档提供 `Pipeline` 构建的标准化流程,开发者也可参考现有的 `Pipeline` 进行构建。 + +`Pipeline` 的实现位于 `diffsynth/pipelines`,每个 `Pipeline` 包含以下必要的关键组件: + +* `__init__` +* `from_pretrained` +* `__call__` +* `units` +* `model_fn` + +## `__init__` + +在 `__init__` 中,`Pipeline` 进行初始化,以下是一个简易的实现: + +```python +import torch +from PIL import Image +from typing import Union +from tqdm import tqdm +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit +from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model + +class NewDiffSynthPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler() + self.text_encoder: XXX_Model = None + self.dit: YYY_Model = None + self.vae: ZZZ_Model = None + self.in_iteration_models = ("dit",) + self.units = [ + NewDiffSynthPipelineUnit_xxx(), + ... + ] + self.model_fn = model_fn_new +``` + +其中包括以下几部分 + +* `scheduler`: 调度器,用于控制推理的迭代公式中的系数,控制每一步的噪声含量。 +* `text_encoder`、`dit`、`vae`: 模型,自 [Latent Diffusion](https://arxiv.org/abs/2112.10752) 被提出以来,这种三段式模型架构已成为主流的 Diffusion 模型架构,但这并不是一成不变的,`Pipeline` 中可添加任意多个模型。 +* `in_iteration_models`: 迭代中模型,这个元组标注了在迭代中会调用哪些模型。 +* `units`: 模型迭代的前处理单元,详见[`units`](#units)。 +* `model_fn`: 迭代中去噪模型的 `forward` 函数,详见[`model_fn`](#model_fn)。 + +> Q: 模型加载并不发生在 `__init__`,为什么这里仍要将每个模型初始化为 `None`? 
+> +> A: 在这里标注每个模型的类型后,代码编辑器就可以根据每个模型提供代码补全提示,便于后续的开发。 + +## `from_pretrained` + +`from_pretrained` 负责加载所需的模型,让 `Pipeline` 变成可调用的状态。以下是一个简易的实现: + +```python + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + vram_limit: float = None, + ): + # Initialize pipeline + pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder") + pipe.dit = model_pool.fetch_model("yyy_dit") + pipe.vae = model_pool.fetch_model("zzz_vae") + # If necessary, load tokenizers here. + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe +``` + +开发者需要实现其中获取模型的逻辑,对应的模型名称即为[模型接入时填写的模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config) 中的 `"model_name"`。 + +部分模型还需要加载 `tokenizer`,可根据需要在 `from_pretrained` 上添加额外的 `tokenizer_config` 参数并在获取模型后实现这部分。 + +## `__call__` + +`__call__` 实现了整个 Pipeline 的生成过程,以下是常见的生成过程模板,开发者可根据需要在此基础上修改。 + +```python + @torch.no_grad() + def __call__( + self, + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + input_image: Image.Image = None, + denoising_strength: float = 1.0, + height: int = 1328, + width: int = 1328, + seed: int = None, + rand_device: str = "cpu", + num_inference_steps: int = 30, + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps( + num_inference_steps, + denoising_strength=denoising_strength + ) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, + "denoising_strength": denoising_strength, + "height": height, + "width": width, + "seed": seed, + "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id) + if cfg_scale != 1.0: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"], device=self.device) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image +``` + +## `units` + +`units` 包含了所有的前处理过程,例如:宽高检查、提示词编码、初始噪声生成等。在整个模型前处理过程中,数据被抽象为了互斥的三部分,分别存储在对应的字典中: + +* `inputs_shard`: 共享输入,与 [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598)(简称 CFG)无关的参数。 +* `inputs_posi`: Classifier-Free Guidance 的 Positive 侧输入,包含与正向提示词相关的内容。 +* `inputs_nega`: Classifier-Free Guidance 的 Negative 
侧输入,包含与负向提示词相关的内容。 + +Pipeline Unit 的实现包括三种:直接模式、CFG 分离模式、接管模式。 + +如果某些计算与 CFG 无关,可采用直接模式,例如 Qwen-Image 的随机噪声初始化: + +```python +class QwenImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} +``` + +如果某些计算与 CFG 有关,需分别处理正向和负向提示词,但两侧的输入参数是相同的,可采用 CFG 分离模式,例如 Qwen-image 的提示词编码: + +```python +class QwenImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: + pipe.load_models_to_device(self.onload_model_names) + # Do something + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} +``` + +如果某些计算需要全局的信息,则需要接管模式,例如 Qwen-Image 的实体分区控制: + +```python +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + # Do something + return inputs_shared, inputs_posi, inputs_nega +``` + +以下是 Pipeline Unit 所需的参数配置: + +* `seperate_cfg`: 是否启用 CFG 分离模式 +* `take_over`: 是否启用接管模式 +* `input_params`: 共享输入参数 +* `output_params`: 输出参数 +* `input_params_posi`: Positive 侧输入参数 +* `input_params_nega`: Negative 侧输入参数 +* `onload_model_names`: 需调用的模型组件名 + +在设计 `unit` 时请尽量按照以下原则进行: + +* 缺省兜底:可选功能的 `unit` 输入参数默认为 `None`,而不是 `False` 或其他数值,请对此默认值进行兜底处理。 +* 参数触发:部分 Adapter 模型可能是未被加载的,例如 ControlNet,对应的 `unit` 应当以参数输入是否为 `None` 来控制触发,而不是以模型是否被加载来控制触发。例如当用户输入了 `controlnet_image` 但没有加载 ControlNet 模型时,代码应当给出报错,而不是忽略这些输入参数继续执行。 +* 简洁优先:尽可能使用直接模式,仅当功能无法实现时,使用接管模式。 +* 显存高效:在 `unit` 中调用模型时,请使用 `pipe.load_models_to_device(self.onload_model_names)` 激活对应的模型,请不要调用 `onload_model_names` 之外的其他模型,`unit` 计算完成后,请不要使用 `pipe.load_models_to_device([])` 手动释放显存。 + +> Q: 部分参数并未在推理过程中调用,例如 `output_params`,是否仍有必要配置? 
+> +> A: 这些参数不会影响推理过程,但会影响一些实验性功能,因此我们建议将其配置好。例如“拆分训练”,我们可以将训练中的前处理离线完成,但部分需要梯度回传的模型计算无法拆分,这些参数用于构建计算图从而推断哪些计算是可以拆分的。 + +## `model_fn` + +`model_fn` 是迭代中的统一 `forward` 接口,对于开源模型生态尚未形成的模型,直接沿用去噪模型的 `forward` 即可,例如: + +```python +def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs): + return dit(latents, prompt_emb, timestep) +``` + +对于开源生态丰富的模型,`model_fn` 通常包含复杂且混乱的跨模型推理,以 `diffsynth/pipelines/qwen_image.py` 为例,这个函数中实现的额外计算包括:实体分区控制、三种 ControlNet、Gradient Checkpointing 等,开发者在实现这一部分时要格外小心,避免模块功能之间的冲突。 diff --git a/docs/zh/Developer_Guide/Enabling_VRAM_management.md b/docs/zh/Developer_Guide/Enabling_VRAM_management.md new file mode 100644 index 0000000000000000000000000000000000000000..a067f8d2b87736a4619577cd8e4394844cfc6b41 --- /dev/null +++ b/docs/zh/Developer_Guide/Enabling_VRAM_management.md @@ -0,0 +1,228 @@ +# 细粒度显存管理方案 + +本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 + +## 20B 模型需要多少显存? + +以 Qwen-Image 的 DiT 模型为例,这一模型的参数量达到了 20B,以下代码会加载这一模型并进行推理,需要约 40G 显存,这个模型在显存较小的消费级 GPU 上显然是无法运行的。 + +```python +from diffsynth.core import load_model +from diffsynth.models.qwen_image_dit import QwenImageDiT +from modelscope import snapshot_download +import torch + +snapshot_download( + model_id="Qwen/Qwen-Image", + local_dir="models/Qwen/Qwen-Image", + allow_file_pattern="transformer/*" +) +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda") +with torch.no_grad(): + output = model(**inputs) +``` + +## 编写细粒度显存管理方案 + +为了编写细粒度的显存管理方案,我们需用 `print(model)` 观察和分析模型结构: + +``` +QwenImageDiT( + (pos_embed): QwenEmbedRope() + (time_text_embed): TimestepEmbeddings( + (time_proj): TemporalTimesteps() + (timestep_embedder): DiffusersCompatibleTimestepProj( + (linear_1): Linear(in_features=256, out_features=3072, bias=True) + (act): SiLU() + (linear_2): Linear(in_features=3072, out_features=3072, bias=True) + ) + ) + (txt_norm): RMSNorm() + (img_in): Linear(in_features=64, out_features=3072, bias=True) + (txt_in): Linear(in_features=3584, out_features=3072, bias=True) + (transformer_blocks): ModuleList( + (0-59): 60 x QwenImageTransformerBlock( + (img_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (attn): QwenDoubleStreamAttention( + (to_q): Linear(in_features=3072, out_features=3072, bias=True) + (to_k): Linear(in_features=3072, out_features=3072, bias=True) + (to_v): Linear(in_features=3072, out_features=3072, bias=True) + (norm_q): RMSNorm() + (norm_k): RMSNorm() + (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True) + (norm_added_q): RMSNorm() + (norm_added_k): RMSNorm() + (to_out): Sequential( + (0): Linear(in_features=3072, out_features=3072, bias=True) + ) + (to_add_out): 
Linear(in_features=3072, out_features=3072, bias=True) + ) + (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (img_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + (txt_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + ) + ) + (norm_out): AdaLayerNorm( + (linear): Linear(in_features=3072, out_features=6144, bias=True) + (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + ) + (proj_out): Linear(in_features=3072, out_features=64, bias=True) +) +``` + +在显存管理中,我们只关心包含参数的 Layer。在这个模型结构中,`QwenEmbedRope`、`TemporalTimesteps`、`SiLU` 等 Layer 都是不包含参数的,`LayerNorm` 也因为设置了 `elementwise_affine=False` 不包含参数。包含参数的 Layer 只有 `Linear` 和 `RMSNorm`。 + +`diffsynth.core.vram` 中提供了两个用于替换的模块用于显存管理: +* `AutoWrappedLinear`: 用于替换 `Linear` 层 +* `AutoWrappedModule`: 用于替换其他任意层 + +编写一个 `module_map`,将模型中的 `Linear` 和 `RMSNorm` 映射到对应的模块上: + +```python +module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, +} +``` + +此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。 + +调用 `enable_vram_management` 即可启用显存管理,注意此时模型加载时的 `device` 为 `cpu`,与 `offload_device` 一致: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu") +enable_vram_management( + model, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +with torch.no_grad(): + output = model(**inputs) +``` + +以上代码只需要 2G 显存就可以运行 20B 模型的 `forward`。 + +## Disk Offload + +[Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" 
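+# 九个 .safetensors 分片权重文件;Disk Offload 仅支持 .safetensors 格式,参数将在计算时按需从磁盘读取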
+model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model( + QwenImageDiT, + model_path, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config={ + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +with torch.no_grad(): + output = model(**inputs) +``` + +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 + +如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。 + +## 写入默认配置 + +为了让用户能够更方便地使用显存管理功能,我们将细粒度显存管理的配置写在 `diffsynth/configs/vram_management_module_maps.py` 中,上述模型的配置信息为: + +```python +"diffsynth.models.qwen_image_dit.QwenImageDiT": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", +} +``` diff --git a/docs/zh/Developer_Guide/Integrating_Your_Model.md b/docs/zh/Developer_Guide/Integrating_Your_Model.md new file mode 100644 index 0000000000000000000000000000000000000000..cd58cfca33a60d39800f150b6f80ed361b844c05 --- /dev/null +++ b/docs/zh/Developer_Guide/Integrating_Your_Model.md @@ -0,0 +1,186 @@ +# 接入模型结构 + +本文档介绍如何将模型接入到 `DiffSynth-Studio` 框架中,供 `Pipeline` 等模块调用。 + +## Step 1: 集成模型结构代码 + +`DiffSynth-Studio` 中的所有模型结构实现统一在 `diffsynth/models` 中,每个 `.py` 代码文件分别实现一个模型结构,所有模型通过 `diffsynth/models/model_loader.py` 中的 `ModelPool` 来加载。在接入新的模型结构时,请在这个路径下建立新的 `.py` 文件。 + +```shell +diffsynth/models/ +├── general_modules.py +├── model_loader.py +├── qwen_image_controlnet.py +├── qwen_image_dit.py +├── qwen_image_text_encoder.py +├── qwen_image_vae.py +└── ... +``` + +在大多数情况下,我们建议用 `PyTorch` 原生代码的形式集成模型,让模型结构类直接继承 `torch.nn.Module`,例如: + +```python +import torch + +class NewDiffSynthModel(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + x = self.linear(x) + x = self.activation(x) + return x +``` + +如果模型结构的实现中包含额外的依赖,我们强烈建议将其删除,否则这会导致沉重的包依赖问题。在我们现有的模型中,Qwen-Image 的 Blockwise ControlNet 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_controlnet.py`。 + +如果模型已被 Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index)、[`diffusers`](https://huggingface.co/docs/diffusers/main/index) 等)集成,我们能够以更简单的方式集成模型: + +
+集成 Huggingface Library 风格模型结构代码 + +这类模型在 Huggingface Library 中的加载方式为: + +```python +from transformers import XXX_Model + +model = XXX_Model.from_pretrained("path_to_your_model") +``` + +`DiffSynth-Studio` 不支持通过 `from_pretrained` 加载模型,因为这与显存管理等功能是冲突的,请将模型结构改写成以下格式: + +```python +import torch + +class DiffSynth_XXX_Model(torch.nn.Module): + def __init__(self): + super().__init__() + from transformers import XXX_Config, XXX_Model + config = XXX_Config(**{ + "architectures": ["XXX_Model"], + "other_configs": "Please copy and paste the other configs here.", + }) + self.model = XXX_Model(config) + + def forward(self, x): + outputs = self.model(x) + return outputs +``` + +其中 `XXX_Config` 为模型对应的 Config 类,例如 `Qwen2_5_VLModel` 的 Config 类为 `Qwen2_5_VLConfig`,可通过查阅其源代码找到。Config 内部的内容通常可以在模型库中的 `config.json` 中找到,`DiffSynth-Studio` 不会读取 `config.json` 文件,因此需要将其中的内容复制粘贴到代码中。 + +在少数情况下,`transformers` 和 `diffusers` 的版本更新会导致部分的模型无法导入,因此如果可能的话,我们仍建议使用 Step 1.1 中的模型集成方式。 + +在我们现有的模型中,Qwen-Image 的 Text Encoder 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_text_encoder.py`。 + +
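+
+作为补充,下面给出一个最小示意(并非 `qwen_image_text_encoder.py` 的实际实现),将上述模板具体化为 Qwen2.5-VL 的情形。其中 Config 的字段与数值仅为占位示意,实际内容请从对应模型库的 `config.json` 中完整复制,且假设所安装的 `transformers` 版本包含 `Qwen2_5_VLConfig` 与 `Qwen2_5_VLModel`:
+
+```python
+import torch
+
+class DiffSynthQwen25VLEncoder(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel
+        # 将模型库 config.json 中的内容复制粘贴到这里(以下字段仅为示意,并不完整)
+        config = Qwen2_5_VLConfig(**{
+            "architectures": ["Qwen2_5_VLModel"],
+            # "hidden_size": ..., "num_hidden_layers": ..., 其余字段同理粘贴
+        })
+        self.model = Qwen2_5_VLModel(config)
+
+    def forward(self, input_ids, attention_mask=None):
+        # 纯文本前向,不涉及视觉输入
+        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        return outputs
+```
+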
+ +## Step 2: 模型文件格式转换 + +由于开源社区中开发者提供的模型文件格式多种多样,因此我们有时需对模型文件格式进行转换,从而形成格式正确的 [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html),常见于以下几种情况: + +* 模型文件由不同代码库构建,例如 [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 和 [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)。 +* 模型在接入中做了修改,例如 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 的 Text Encoder 在 `diffsynth/models/qwen_image_text_encoder.py` 中增加了 `model.` 前缀。 +* 模型文件包含多个模型,例如 [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) 的 VACE Adapter 和基础 DiT 模型混合存储在同一组模型文件中。 + +在我们的开发理念中,我们希望尽可能尊重模型原作者的意愿。如果对模型文件进行重新封装,例如 [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI),虽然我们可以更方便地调用模型,但流量(模型页面浏览量和下载量等)会被引向他处,模型的原作者也会失去删除模型的权力。因此,我们在框架中增加了 `diffsynth/utils/state_dict_converters` 这一模块,用于在模型加载过程中进行文件格式转换。 + +这部分逻辑是非常简单的,以 Qwen-Image 的 Text Encoder 为例,只需要 10 行代码即可: + +```python +def QwenImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for k in state_dict: + v = state_dict[k] + if k.startswith("visual."): + k = "model." + k + elif k.startswith("model."): + k = k.replace("model.", "model.language_model.") + state_dict_[k] = v + return state_dict_ +``` + +## Step 3: 编写模型 Config + +模型 Config 位于 `diffsynth/configs/model_configs.py`,用于识别模型类型并进行加载。需填入以下字段: + +* `model_hash`:模型文件哈希值,可通过 `hash_model_file` 函数获取,此哈希值仅与模型文件中 state dict 的 keys 和张量 shape 有关,与文件中的其他信息无关。 +* `model_name`: 模型名称,用于给 `Pipeline` 识别所需模型。如果不同结构的模型在 `Pipeline` 中发挥的作用相同,则可以使用相同的 `model_name`。在接入新模型时,只需保证 `model_name` 与现有的其他功能模型不同即可。在 `Pipeline` 的 `from_pretrained` 中通过 `model_name` 获取对应的模型。 +* `model_class`: 模型结构导入路径,指向在 Step 1 中实现的模型结构类,例如 `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`。 +* `state_dict_converter`: 可选参数,如需进行模型文件格式转换,则需填入模型转换逻辑的导入路径,例如 `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`。 +* `extra_kwargs`: 可选参数,如果模型初始化时需传入额外参数,则需要填入这些参数,例如模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 与 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) 都采用了 `diffsynth/models/qwen_image_controlnet.py` 中的 `QwenImageBlockWiseControlNet` 结构,但后者还需额外的配置 `additional_in_dim=4`,因此这部分配置信息需填入 `extra_kwargs` 字段。 + +我们提供了一份代码,以便快速理解模型是如何通过这些配置信息加载的: + +```python +from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization +from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder +from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter +import torch + +model_hash = "8004730443f55db63092006dd9f7110e" +model_name = "qwen_image_text_encoder" +model_class = QwenImageTextEncoder +state_dict_converter = QwenImageTextEncoderStateDictConverter +extra_kwargs = {} + +model_path = [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", +] +if hash_model_file(model_path) == model_hash: + with skip_model_initialization(): + model = model_class(**extra_kwargs) + state_dict = 
load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda") + state_dict = state_dict_converter(state_dict) + model.load_state_dict(state_dict, assign=True) + print("Done!") +``` + +> Q: 上述代码的逻辑看起来很简单,为什么 `DiffSynth-Studio` 中的这部分代码极为复杂? +> +> A: 因为我们提供了激进的显存管理功能,与模型加载逻辑耦合,这导致框架结构的复杂性,我们已尽可能简化暴露给开发者的接口。 + +`diffsynth/configs/model_configs.py` 中的 `model_hash` 不是唯一存在的,同一模型文件中可能存在多个模型。对于这种情况,请使用多个模型 Config 分别加载每个模型,编写相应的 `state_dict_converter` 分离每个模型所需的参数。 + +## Step 4: 检验模型是否能被识别和加载 + +模型接入之后,可通过以下代码验证模型是否能够被正确识别和加载,以下代码会试图将模型加载到内存中: + +```python +from diffsynth.models.model_loader import ModelPool + +model_pool = ModelPool() +model_pool.auto_load_model( + [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", + ], +) +``` + +如果模型能够被识别和加载,则会看到以下输出内容: + +``` +Loading models from: [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +] +Loaded model: { + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "extra_kwargs": null +} +``` + +## Step 5: 编写模型显存管理方案 + +`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。 diff --git a/docs/zh/Developer_Guide/Training_Diffusion_Models.md b/docs/zh/Developer_Guide/Training_Diffusion_Models.md new file mode 100644 index 0000000000000000000000000000000000000000..4313fa14caa3d9849376288744acd3d2d570be15 --- /dev/null +++ b/docs/zh/Developer_Guide/Training_Diffusion_Models.md @@ -0,0 +1,66 @@ +# 接入模型训练 + +在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)并[实现 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md)后,接下来接入模型训练功能。 + +## 训推一致的 Pipeline 改造 + +为了保证训练和推理过程严格的一致性,我们会在训练过程中沿用大部分推理代码,但仍需作出少量改造。 + +首先,在推理过程中添加额外的逻辑,让图生图/视频生视频逻辑根据 `scheduler` 状态进行切换。以 Qwen-Image 为例: + +```python +class QwenImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} +``` + +然后,在 `model_fn` 中启用 Gradient Checkpointing,这将以计算速度为代价,大幅度减少训练所需的显存。这并不是必需的,但我们强烈建议这么做。 + +以 Qwen-Image 为例,修改前: + +```python +text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, +) +``` + +修改后: + +```python +from ..core import gradient_checkpoint_forward + +text, image = 
gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, +) +``` + +## 编写训练脚本 + +`DiffSynth-Studio` 没有对训练框架做严格的封装,而是将脚本内容暴露给开发者,这种方式可以更方便地对训练脚本进行修改,实现额外的功能。开发者可参考现有的训练脚本,例如 `examples/qwen_image/model_training/train.py` 进行修改,从而适配新的模型训练。 diff --git a/docs/zh/Model_Details/FLUX.md b/docs/zh/Model_Details/FLUX.md new file mode 100644 index 0000000000000000000000000000000000000000..71576dc84bdd1ea71c9f370313de8a23e2b818e0 --- /dev/null +++ b/docs/zh/Model_Details/FLUX.md @@ -0,0 +1,201 @@ +# FLUX + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +FLUX 是由 Black Forest Labs 开发并开源的图像生成模型系列。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1, +) +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +image = pipe(prompt=prompt, seed=0) +image.save("image.jpg") +``` + +## 模型总览 + +
+ +模型血缘 + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
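+
+下表中的「额外参数」列给出了各衍生模型在推理时需要额外传入 Pipeline 的输入。作为最小示意,下面展示如何在调用时传入常用推理参数(假设 `pipe` 已按上文「快速开始」的方式加载完毕,参数取值仅作演示,各参数的具体含义见下文「模型推理」一节):
+
+```python
+image = pipe(
+    prompt="a cat sitting on a wooden table, cinematic lighting",
+    negative_prompt="blurry, low quality",
+    cfg_scale=2.0,                # 大于 1 时启用 Classifier-free guidance
+    embedded_guidance=3.5,        # 嵌入式引导参数
+    height=1024, width=1024,      # 需为 16 的倍数
+    num_inference_steps=30,
+    seed=0,
+    tiled=True,                   # 减少 VAE 编解码阶段的显存占用
+    # kontext_images=...,         # 衍生模型所需的额外输入,见下表「额外参数」列
+)
+image.save("image_with_params.jpg")
+```
+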
+ +|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| +|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| 
+|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| 
+|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| + +特殊训练脚本: + +* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/) +* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) + +## 模型推理 + +模型通过 `FluxImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`FluxImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于 1 的值时启用 CFG。 +* `height`: 图像高度,需保证高度为 16 的倍数。 +* `width`: 图像宽度,需保证宽度为 16 的倍数。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 30。 +* `embedded_guidance`: 嵌入式引导参数,默认值为 3.5。 +* `t5_sequence_length`: T5 文本编码器的序列长度,默认为 512。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 +* `controlnet_inputs`: ControlNet 模型的输入,类型为 `ControlNetInput` 列表。 +* `ipadapter_images`: IP-Adapter 模型的输入图像列表。 +* `ipadapter_scale`: IP-Adapter 模型的引导强度。 +* `infinityou_id_image`: InfiniteYou 模型的输入图像。 +* `infinityou_guidance`: InfiniteYou 模型的引导强度。 +* `kontext_images`: Kontext 模型的输入图像。 +* `eligen_entity_prompts`: EliGen 分区控制的提示词列表。 +* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像列表。 +* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。 +* `eligen_enable_inpaint`: 是否启用 EliGen 分区控制的局部重绘功能。 +* `lora_encoder_inputs`: LoRA 编码器的输入图像列表。 +* `lora_encoder_scale`: LoRA 编码器的引导强度。 +* `step1x_reference_image`: Step1X 模型的参考图像。 +* `flex_inpaint_image`: Flex 模型的待修复图像。 +* `flex_inpaint_mask`: Flex 模型的修复遮罩。 +* `flex_control_image`: Flex 模型的控制图像。 +* `flex_control_strength`: Flex 模型的控制强度。 +* `flex_control_stop`: Flex 模型的控制停止时间步。 +* `nexus_gen_reference_image`: Nexus-Gen 模型的参考图像。 + +如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。 + * 
`--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 图像宽高配置(适用于图像生成模型和视频生成模型) + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 +* FLUX 专有参数 + * `--tokenizer_1_path`: CLIP tokenizer 的路径,留空则自动从远程下载。 + * `--tokenizer_2_path`: T5 tokenizer 的路径,留空则自动从远程下载。 + * `--align_to_opensource_format`: 是否将 LoRA 格式对齐到开源格式,仅适用于 DiT 的 LoRA。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file diff --git a/docs/zh/Model_Details/FLUX2.md b/docs/zh/Model_Details/FLUX2.md new file mode 100644 index 0000000000000000000000000000000000000000..ad4df2781bf65222ee0ed9393853acb51aeb7ed3 --- /dev/null +++ b/docs/zh/Model_Details/FLUX2.md @@ -0,0 +1,138 @@ +# FLUX.2 + +FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . 
+``` + +更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。 + +```python +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image.jpg") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-| +|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)| + +特殊训练脚本: + +* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/) +* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh) + +## 模型推理 + +模型通过 `Flux2ImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`Flux2ImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于 1 的值时启用 CFG。 +* `height`: 图像高度,需保证高度为 16 的倍数。 +* `width`: 图像宽度,需保证宽度为 16 的倍数。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 30。 +* `embedded_guidance`: 嵌入式引导参数,默认值为 3.5。 +* `t5_sequence_length`: T5 文本编码器的序列长度,默认为 512。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `progress_bar_cmd`: 进度条,默认为 
`tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 图像宽高配置(适用于图像生成模型和视频生成模型) + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 +* FLUX.2 专有参数 + * `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file diff --git a/docs/zh/Model_Details/Overview.md b/docs/zh/Model_Details/Overview.md new file mode 100644 index 0000000000000000000000000000000000000000..9c0e6792f3bc870c2d2e454d2fa3147a6dca877e --- /dev/null +++ b/docs/zh/Model_Details/Overview.md @@ -0,0 +1,288 @@ +# 模型目录 + +## Qwen-Image + +文档:[./Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md) + +
+ +效果一览 + +![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924) + +
+ +
+ +快速开始 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from PIL import Image +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, num_inference_steps=40, + # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit +) +image.save("image.jpg") +``` + +
+ +
+ +模型血缘 + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| 
+|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| + +## FLUX 系列 + +文档:[./FLUX.md](/docs/zh/Model_Details/FLUX.md) + +
+ +效果一览 + +![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d) + +
+ +
+ +快速开始 + +```python +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image = pipe(prompt="a cat", seed=0) +image.save("image.jpg") +``` + +
+ +
+ +模型血缘 + +```mermaid +graph LR; + FLUX.1-Series-->black-forest-labs/FLUX.1-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev; + FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev; + black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series; + FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta; + FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha; + FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler; + black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter; + black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev; + black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview; + black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit; + Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit; + black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2; + Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2; +``` + +
+ +|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)| +|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)| +|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| 
+|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| +|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| +|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| +|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| +|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| +|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)| +|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)| 
+|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)| + +## Wan 系列 + +文档:[./Wan.md](/docs/zh/Model_Details/Wan.md) + +
+ +效果一览 + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +
+ +
+ +快速开始 + +```python +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) +``` + +
+ +
+ +模型血缘 + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
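+
+下表中的「额外参数」列给出了各衍生模型在推理时需要额外传入的输入。以图生视频(I2V / InP)类模型为例,只需在推理调用中补充相应参数即可,下面是一个最小示意(假设 `pipe` 已按上文「快速开始」的方式换成对应模型的权重加载,`save_video` 沿用「快速开始」中的导入,文件名仅为示例):
+
+```python
+from PIL import Image
+
+video = pipe(
+    prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清",
+    input_image=Image.open("first_frame.jpg"),  # 仅 I2V / InP 等模型需要,见下表「额外参数」列
+    seed=0, tiled=True,
+)
+save_video(video, "video_i2v.mp4", fps=15, quality=5)
+```
+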
+ +|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| 
+|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md new file mode 100644 index 0000000000000000000000000000000000000000..f2609ac09d94ce5d1ce5612c0e776094df89696a --- /dev/null +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -0,0 +1,192 @@ +# Qwen-Image + +![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924) + +Qwen-Image 是由阿里巴巴通义实验室通义千问团队训练并开源的图像生成模型。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## 模型总览 + +
+ +模型血缘 + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
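+在查看下表的具体模型之前,这里补充一个基于上文“快速开始”中 `pipe` 对象的图生图示意。这不是官方示例:其中的图像路径、提示词与各参数取值均为假设,仅用于说明 `input_image` 与 `denoising_strength` 的配合方式,参数含义详见下文“模型推理”一节。
+
+```python
+from PIL import Image
+
+# 复用“快速开始”中构建好的 pipe;输入图像路径仅为示例
+init_image = Image.open("image.jpg")
+image = pipe(
+    prompt="精致肖像,水下少女,蓝裙飘逸,冷色调,光影透澈,细节精致",
+    input_image=init_image,      # 图生图输入
+    denoising_strength=0.6,      # 越接近 0,生成结果越贴近输入图像
+    seed=0, num_inference_steps=40,
+)
+image.save("image_img2img.jpg")
+```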
+ +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| 
+|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| 
+|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-| + +特殊训练脚本: + +* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/qwen_image/model_training/special/differential_training/) +* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/qwen_image/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/qwen_image/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) + +## 模型推理 + +模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`QwenImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4,当设置为 1 时不再生效。 +* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。 +* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。 +* `inpaint_mask`: 图像局部重绘的遮罩图像。 +* `inpaint_blur_size`: 图像局部重绘的边缘柔化宽度。 +* `inpaint_blur_sigma`: 图像局部重绘的边缘柔化强度。 +* `height`: 图像高度,需保证高度为 16 的倍数。 +* `width`: 图像宽度,需保证宽度为 16 的倍数。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 30。 +* `exponential_shift_mu`: 在采样时间步时采用的固定参数,留空则根据图像宽高进行采样。 +* `blockwise_controlnet_inputs`: Blockwise ControlNet 模型的输入。 +* `eligen_entity_prompts`: EliGen 分区控制的提示词。 +* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像。 +* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。 +* `edit_image`: 编辑模型的待编辑图像,支持多张图像。 +* `edit_image_auto_resize`: 是否自动缩放待编辑图像。 +* `edit_rope_interpolation`: 是否在低分辨率编辑图像上启用 ROPE 插值。 +* `context_image`: In-Context Control 的输入图像。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文“模型总览”中的表格。 + +## 模型训练 + +Qwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 
[torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 图像宽高配置(适用于图像生成模型和视频生成模型) + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 +* Qwen-Image 专有参数 + * `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 + * `--processor_path`: processor 的路径,适用于图像编辑模型,留空则自动从远程下载。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 diff --git a/docs/zh/Model_Details/VACE_Instance_Control.md b/docs/zh/Model_Details/VACE_Instance_Control.md new file mode 100644 index 0000000000000000000000000000000000000000..f091edd25ecd66846989e157d5cbc270578d3818 --- /dev/null +++ b/docs/zh/Model_Details/VACE_Instance_Control.md @@ -0,0 +1,70 @@ +# VACE 实例控制原理 + +VACE(Wan 系列视频模型中的 Video Adapter for Controllable Editing)负责把实例级的视觉约束注入到主干 DiT 去噪器中。它的工作方式与 ControlNet 类似,但它直接处理视频 VAE 潜变量,并以跨层提示(hint)的形式影响特定的 Transformer Block。本说明基于 `diffsynth/pipelines/wan_video.py` 与 `diffsynth/models/wan_video_vace.py` 中的实现。 + +## 控制输入是如何组织的 + +1. **分离被控实例与背景**(`wan_video.py:616-671`) + - 用户提供 `vace_video`(目标动作参考)和 `vace_video_mask`(逐像素的实例遮罩)。 + - 管线首先将视频和遮罩正规化,然后用 VAE 编码器分别得到 “inactive”(背景:`video * (1 - mask)`)和 “reactive”(实例:`video * mask`)潜变量,并在通道维拼接,保留了“这个像素是否由实例驱动”的信息。 + - 遮罩被重排成 8×8 patch,并插值到与潜变量时间分辨率一致,形成额外的 64 个通道,帮助模型知道哪些 token 需要强约束。 + +2. **可选的参考图像补帧** + - `vace_reference_image` 会被打散成逐帧图像并送入同一个 VAE,输出与视频相同的 latent 格式。 + - 这些参考帧被拼在时间维的开头,并配套 0 值遮罩,使得模型在最开始的几个时间步直接看到参考姿态或风格。 + +3. 
**统一的 VACE 上下文张量** + - 最终 `vace_video_latents`(inactive+reactive)与 `vace_mask_latents` 被 concat 成 `vace_context`,形状约为 `[B, 96, T/4, H/8, W/8]`(以默认 VAE + patch size 为例)。 + - 如果没有传入任何控制项,则返回 `None` 并跳过整个模块。 + +## VACE 模块内部结构 + +### 3D Patch Embedding 对齐 token(`wan_video_vace.py:23-38`) + +`VaceWanModel` 通过一个 `Conv3d`(`kernel_size=(1,2,2)`)对 `vace_context` 进行 patch 化,使输出 token 的时空分布与主干 DiT 的 latent token(`x`)严格一致。这一步确保后续可以逐 token 相加而无需插值。 + +### 与主干 DiT 同步的注意力块(`wan_video_vace.py:5-22, 40-64`) + +- 每个 `VaceWanAttentionBlock` 继承自 `DiTBlock`,在第一个 block 内将 patch token 与当前的 DiT 激活 `x` 相加(`self.before_proj(c) + x`),以便直接读取主干的上下文。 +- 每次前向都会保存一个 `after_proj` 过的 skip 分支,并把它压入堆栈。下一层会弹出上一层的输出继续计算。 +- 完成所有层后,`torch.unbind(c)[:-1]` 得到的就是与 `vace_layers` 一一对应的 hint 序列,每个 hint 的形状与 `x` 完全相同。 + +这种堆栈式设计允许 VACE 和主干 DiT 保持“层级对齐”:第 *i* 个 hint 只依赖第 *i* 层之前看到的特征,从而更像一个逐层附加的 Adapter,而不是单次生成所有控制信号。 + +### 计算和显存开销控制 + +`VaceWanModel.forward()` 支持 `use_gradient_checkpointing` 与 `use_gradient_checkpointing_offload` 参数,在长视频或 14B 模型上可以显著降低峰值显存。 + +## Hint 如何注入 Wan DiT + +1. **先在时间步开始时生成所有 hint**(`wan_video.py:1299-1306`): + + ```python + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs, ...) + ``` + +2. **在指定的 Transformer Block 中相加**(`wan_video.py:1336-1370`): + - `VaceWanModel` 暴露了 `vace_layers`(默认 `(0, 2, 4, …, 28)`),通过 `vace_layers_mapping` 把 block id 映射到 hint 序号。 + - 当循环到对应的 block 时,执行 `x = x + current_vace_hint * vace_scale`。 + - 若开启 Unified Sequence Parallel,还会对 hint 做同样的 `chunk` 与 `pad`,保证分布式场景正常。 + +3. **`vace_scale` 控制注入强度** + - 小于 1 会弱化控制,趋近 0 时退化成原始 DiT;大于 1 会让生成更严格地跟随输入区域,但也可能带来 artifact。 + - 该参数可以在推理时自由调整,属于最直接的“实例控制力度”调节钮。 + +## 整体实例控制流程总结 + +1. 构造实例条件:用户提供控制视频、遮罩以及可选参考图像。遮罩把 latent 分成实例(reactive)与背景(inactive),并提供额外 mask token 告知模型哪里需要被“硬约束”。 +2. VACE 通过 3D patch embedding 和与主干对齐的 DiT block,生成一组跨层 hint。由于每层 hint 看到的都是局部且 mask-aware 的 token,它可以理解“这块区域属于某个实例,应按输入动作移动”。 +3. 主干 Wan DiT 在每个指定层把 hint 加到自身激活上,相当于在不同深度注入“我要让这个实例保持外观 / 跟随这段骨架”的约束。 +4. `vace_scale`、`vace_layers`、遮罩形状共同决定了控制强度与范围,可通过只提供局部 mask 来达到实例级编辑、背景保护、局部跟随等效果。 + +## 实践建议 + +- 生成实例遮罩时尽量与输入视频逐帧对齐,否则 inactive/active 混淆会削弱控制效果。 +- 当只需锁定某个主体的外观,可以提供静态 `vace_reference_image` 而把 `vace_video` 设为全零,通过遮罩标记该主体所在区域;这样模型会在初始帧看到清晰参照。 +- 如果需要更强的约束,可以在配置文件中把 `vace_layers` 扩展到更深的层数,让 hint 影响更多语义层,但要注意显存和可能的 overfit。 +- 在多机推理时记得同步 `vace_scale`,并保持遮罩尺寸能被 8 整除(VAE 下采样 + 8×8 patch),否则会被 `rearrange` 抛错。 + +通过以上机制,VACE 在 Wan 视频模型中提供了一个“实例级 ControlNet”:它不直接更改像素,而是以多层特征 hint 的方式牵引主干去噪过程,让指定的实例在时空上维持用户期望的动作与外观。 diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md new file mode 100644 index 0000000000000000000000000000000000000000..b8c3032d4a064c33d3ff8b144ebdca0de550e971 --- /dev/null +++ b/docs/zh/Model_Details/Wan.md @@ -0,0 +1,253 @@ +# Wan + +https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 + +Wan 是由阿里巴巴通义实验室通义万相团队开发的视频生成模型系列。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . 
+``` + +更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。 + +```python +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) +``` + +## 模型总览 + +
+ +模型血缘 + +```mermaid +graph LR; + Wan-Series-->Wan2.1-Series; + Wan-Series-->Wan2.2-Series; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B; + Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P; + Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P; + Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview; + iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP; + Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP; + Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control; + Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP; + Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera; + Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP; + Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera; + Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1; + Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video; + Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video; + Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B; + Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B; + Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B; + Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B; + Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series; + Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control; + Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera; +``` + +
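+在查看下表的具体模型之前,这里补充一个基于上文“快速开始”中 `pipe` 对象的视频到视频生成示意。这不是官方示例:输入视频路径、`VideoData` 的尺寸参数以及各项取值均为假设,仅用于说明 `input_video` 与 `denoising_strength` 的配合方式,参数含义详见下文“模型推理”一节。
+
+```python
+from diffsynth.utils.data import VideoData, save_video
+
+# 复用“快速开始”中构建好的 pipe;输入视频路径与尺寸均为示例值
+input_video = VideoData("video.mp4", height=480, width=832)
+video = pipe(
+    prompt="纪实摄影风格画面,一只小狗在雪地上奔跑,镜头跟随移动",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清",
+    input_video=input_video,     # 视频到视频输入
+    denoising_strength=0.7,      # 越接近 0,生成结果越贴近输入视频
+    seed=0, tiled=True,
+)
+save_video(video, "video_v2v.mp4", fps=15, quality=5)
+```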
+ +|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)| +|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)| +|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)| +|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)| +|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)| +|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)| +|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| +|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| +|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)| +|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)| +|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)| +|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)| +|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| 
+|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)| +|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)| +|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, 
`end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| +|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| +|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| + +* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/wanvideo/model_training/special/fp8_training/) +* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/wanvideo/model_training/special/split_training/) +* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/wanvideo/model_training/special/direct_distill/) + +## 模型推理 + +模型通过 `WanVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`WanVideoPipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述视频中出现的内容。 +* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 5,当设置为 1 时不再生效。 +* `input_image`: 输入图像,用于图生视频,该参数与 `denoising_strength` 配合使用。 +* `end_image`: 结束图像,用于首尾帧生成视频。 +* `input_video`: 输入视频,用于视频到视频生成,该参数与 `denoising_strength` 配合使用。 +* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成视频与输入视频相似;当数值接近 1 时,生成视频与输入视频相差更大。 +* `control_video`: 控制视频,用于控制视频生成过程。 +* `reference_image`: 参考图像,用于保持生成视频中某些特征的一致性。 +* `camera_control_direction`: 相机控制方向,可选值为 `"Left"`, `"Right"`, `"Up"`, `"Down"`, `"LeftUp"`, `"LeftDown"`, `"RightUp"`, `"RightDown"`。 +* `camera_control_speed`: 相机控制速度,默认值为 1/54。 +* `vace_video`: VACE 控制视频。 +* `vace_video_mask`: VACE 控制视频遮罩。 +* `vace_reference_image`: VACE 参考图像。 +* `vace_scale`: VACE 控制强度,默认值为 1.0。 +* `animate_pose_video`: `animate` 模型姿态视频。 +* `animate_face_video`: `animate` 模型面部视频。 +* `animate_inpaint_video`: `animate` 模型局部编辑视频。 +* `animate_mask_video`: `animate` 模型遮罩视频。 +* `vap_video`: `video-as-prompt` 的输入视频。 +* `vap_prompt`: `video-as-prompt` 的文本描述。 +* `negative_vap_prompt`: `video-as-prompt` 的负向文本描述。 +* `input_audio`: 输入音频,用于语音到视频生成。 +* `audio_embeds`: 音频嵌入向量。 +* `audio_sample_rate`: 音频采样率,默认值为 16000。 +* `s2v_pose_video`: S2V 模型的姿态视频。 +* `motion_video`: S2V 模型的运动视频。 +* `height`: 视频高度,需保证高度为 16 的倍数。 +* `width`: 视频宽度,需保证宽度为 16 的倍数。 +* `num_frames`: 视频帧数,默认值为 81,需保证为 4 的倍数 + 1。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 50。 +* `motion_bucket_id`: 运动控制参数,数值越大,运动幅度越大。 +* 
`longcat_video`: LongCat 输入视频。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 `(30, 52)`,仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 `(15, 26)`,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `switch_DiT_boundary`: 切换DiT模型的时间边界,默认值为 0.875。 +* `sigma_shift`: 时间步偏移参数,默认值为 5.0。 +* `sliding_window_size`: 滑动窗口大小。 +* `sliding_window_stride`: 滑动窗口步长。 +* `tea_cache_l1_thresh`: TeaCache 的 L1 阈值。 +* `tea_cache_model_id`: TeaCache 使用的模型 ID。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型时需要额外参数,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 视频宽高配置 + * `--height`: 视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的视频帧都会被缩小,分辨率小于这个数值的视频帧保持不变。 + * `--num_frames`: 视频的帧数。 +* Wan 系列专有参数 + * `--tokenizer_path`: tokenizer 的路径,适用于文生视频模型,留空则自动从远程下载。 + * `--audio_processor_path`: 音频处理器的路径,适用于语音到视频模型,留空则自动从远程下载。 + +我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 \ No newline at end of file diff --git a/docs/zh/Model_Details/Wan_Instance_StateMachine_T5_BBox.md b/docs/zh/Model_Details/Wan_Instance_StateMachine_T5_BBox.md new file mode 100644 index 0000000000000000000000000000000000000000..2b1c45e921bdc6a921429252ab6ade4178d203c4 --- /dev/null +++ 
b/docs/zh/Model_Details/Wan_Instance_StateMachine_T5_BBox.md @@ -0,0 +1,71 @@ +# Wan 实例 StateMachine(T5 语义 + BBox Mask)方案草案 + +基于 `diffsynth/models/wan_video_dit_statemachine_1.py` 的实验版 StateMachine,这里整理一套“实例 ID + 文本语义 + BBox”设计,核心目标是让实例控制与文本语义对齐:`<class> is <state>` 短语用 T5 编码,遮罩用 bbox 自动生成,尽量减少手工语义对齐成本。 + +## 设计动机 +- 类别和状态本身有语义,embedding 表应该来自语言模型而不是离散 ID。 +- 文本提示/负提示与实例控制共享语义空间,避免“prompt 说羊在吃草,但实例 state_emb 不懂吃草”。 +- 用 bbox 代替像素级 mask,降低标注/推理成本,并便于跨帧插值或填充。 + +## 输入约定(bbox only) +- `instance_id`:仅用于区分同一类别的不同个体,可继续用可训练的 `nn.Embedding`。 +- `class_text`:自由文本(如 `"sheep"`)。 +- `state_text`:自由文本(如 `"eating"` 或 `"open"`)。 +- `bbox`:形状 `(B, N, F, 4)`,xyxy 像素坐标;若为 `(B, N, 4)` 则对所有帧广播。可选 `bbox_mask` 表示实例在某帧是否存在。 +- `class_state_text_embeds` / `instance_text_input_ids`:如果想在模型内部跑 T5,可传 token ids;否则可传已经编码好的 `<class> is <state>` 向量。 + +## 文本编码(T5) +1. 构造语义短语:`"{class_text} is {state_text}"`(必要时加少量 prompt 工程,如 `"a {class_text} that is {state_text}"`)。 +2. 通过 T5-encoder 获得 token hidden states;使用 `[EOS]` 或 mean pool 得到 `(B, N, D_t5)`。 +3. 线性映射到 DiT 维度:`Linear(D_t5 -> dim)`,再 LayerNorm。 +4. 与 `instance_id_emb` 融合:`fusion([t5_sem, inst_id_emb]) -> inst_token`。可继续保留 gate 以便初始不破坏生成。 + +对齐策略:T5 编码器最好与主模型文本编码共享词表或直接复用 T5 作为主 prompt 编码器,保证同一语义空间;否则需单独训练配准。 + +## BBox → Patch Mask +1. 将 bbox 从像素坐标映射到 patch grid: + - `H_p = H / patch_size_h`, `W_p = W / patch_size_w`。 + - 对每帧 bbox 取整到 patch 单位:`x0//ps_w, x1//ps_w, ...`,保证 `x1>=x0+1`。 +2. 构建 `(B, N, F, H_p, W_p)` 二值 mask,前景=1;如果只给 `(B, N, 4)`,则对每帧复用。 +3. 展平为 `(B, N, L)` 供 `MaskGuidedCrossAttention` 使用(`L = F_p * H_p * W_p`)。 +4. 可选:对 bbox 边缘做软扩张(dilation 1~2 个 patch)以缓冲量化误差。 + +## 前向路径修改要点 +- `InstanceFeatureExtractor`:新增 `text_dim` & `text_to_class_state`,直接将 `<class> is <state>` 向量拆成类/状态两份,再与 `instance_id_emb` 融合;若未提供文本则回退到 `class_id/state_id` 路径。 +- `forward` 输入:新增 + - `instance_text_input_ids`/`instance_text_attention_mask`(内部跑 T5); + - `instance_class_state_text_embeds`(外部已编码的 T5 向量); + - `instance_bboxes` / `instance_bbox_mask`(bbox -> patch mask)。 +- `process_masks`:增加 bbox 分支,将 xyxy 投影到 patch grid,支持帧级 mask;仍兼容旧的像素级 mask。 +- 其余链路保持不变:实例 tokens 仍在每个 DiTBlock 通过 `instance_tokens/instance_masks` 触发一次 mask-guided cross-attention。 + +## 模型结构图(简化) +```mermaid +flowchart LR + A[Class text] -->|模板 \"class is state\"| T5[T5 Encoder] + S[State text] -->|同上| T5 + T5 --> P[Pool & Linear to dim] + ID[Instance ID emb] --> FUSE + P --> FUSE[Fusion MLP -> Instance Tokens] + + VID[VAE Latent Video] --> PATCH[3D Patchify] + BBOX[BBox / Mask] --> MASK[BBox->Patch Mask] + + PATCH --> DIT[Wan DiT Blocks] + FUSE --> DIT + MASK --> DIT + DIT --> HEAD[Head / Unpatchify] + HEAD --> OUT[Pred noise] +``` + +## 训练与推理建议 +- 数据:需要实例级的 `class/state/bbox` 标签;若 bbox 稀疏,可用检测/分割模型自动标注。 +- 文本正负提示:确保主 prompt 与实例短语共享编码器,或者在损失中加入对齐项(如对同一语义的 CLIP/T5 空间蒸馏)。 +- 稳定性:保持 `gate` 零初始化,先小步微调,逐步解冻 T5/融合层。 +- 多实例帧缺失:用 `bbox_mask` 将缺失帧的 mask 设为 0,避免实例 token 影响不存在的帧。 +- 性能:bbox→mask 是 O(N*F*H_p*W_p) 的简单填充,可在 dataloader 端完成并缓存。 + +## 与现有实现的差异(对 `wan_video_dit_statemachine_1.py` 的映射) +- 语义来源:从 `class_id/state_id` embedding 切换为 T5 文本编码;`instance_id` embedding 仍保留用于区分个体。 +- mask 生成:由“像素级 mask 下采样”改为“bbox 投影到 patch grid”。 +- 其余控制逻辑(逐层 mask-guided cross-attention、gate、gradient checkpoint)可复用现有代码。 diff --git a/docs/zh/Model_Details/Z-Image.md b/docs/zh/Model_Details/Z-Image.md new file mode 100644 index 0000000000000000000000000000000000000000..ad2818ec60b2bcc0ed409c9ab534b4ce23252644 --- /dev/null +++ b/docs/zh/Model_Details/Z-Image.md @@ -0,0 +1,141 @@ +# Z-Image + +Z-Image 是由阿里巴巴通义实验室多模态交互团队训练并开源的图像生成模型。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone 
https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。 + +## 快速开始 + +运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。 + +```python +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." 
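+# 下面的调用未显式指定采样参数,使用默认配置:num_inference_steps 默认为 8,cfg_scale 默认为 1(参数含义见下文“模型推理”一节)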
+image = pipe(prompt=prompt, seed=42, rand_device="cuda") +image.save("image.jpg") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)| + +特殊训练脚本: + +* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/z_image/model_training/special/differential_training/) +* 轨迹模仿蒸馏训练(实验性功能):[code](/examples/z_image/model_training/special/trajectory_imitation/) + +## 模型推理 + +模型通过 `ZImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`ZImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1。 +* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。 +* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。 +* `height`: 图像高度,需保证高度为 16 的倍数。 +* `width`: 图像宽度,需保证宽度为 16 的倍数。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 8。 + +如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 + +## 模型训练 + +Z-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py) 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型时需要额外参数,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 图像宽高配置(适用于图像生成模型和视频生成模型) + * 
`--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 +* Z-Image 专有参数 + * `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 + +我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +``` + +我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。 + +训练技巧: + +* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 是一个蒸馏加速的模型,因此直接训练将会迅速让模型失去加速能力,以“加速配置”(`num_inference_steps=8`,`cfg_scale=1`)推理的效果变差,以“无加速配置”(`num_inference_steps=30`,`cfg_scale=2`)推理的效果变好。可采用以下方案训练和推理: + * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + 无加速配置推理 + * 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/)) + 加速配置推理 + * 差分 LoRA 训练中需加载一个额外的 LoRA,例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter) + * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理 + * 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + 加速配置推理 diff --git a/docs/zh/Pipeline_Usage/Environment_Variables.md b/docs/zh/Pipeline_Usage/Environment_Variables.md new file mode 100644 index 0000000000000000000000000000000000000000..9c96fccf923c5451961929b16b3faa2d2bcb8456 --- /dev/null +++ b/docs/zh/Pipeline_Usage/Environment_Variables.md @@ -0,0 +1,39 @@ +# 环境变量 + +`DiffSynth-Studio` 可通过环境变量控制一些设置。 + +在 `Python` 代码中,可以使用 `os.environ` 设置环境变量。请注意,环境变量需在 `import diffsynth` 前设置。 + +```python +import os +os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "./path_to_my_models" +import diffsynth +``` + +在 Linux 操作系统上,也可在命令行临时设置环境变量: + +```shell +DIFFSYNTH_MODEL_BASE_PATH="./path_to_my_models" python xxx.py +``` + +以下是 `DiffSynth-Studio` 所支持的环境变量。 + +## `DIFFSYNTH_SKIP_DOWNLOAD` + +是否跳过模型下载。可设置为 `True`、`true`、`False`、`false`,若 `ModelConfig` 中没有设置 `skip_download`,则会根据这一环境变量决定是否跳过模型下载。 + +## `DIFFSYNTH_MODEL_BASE_PATH` + +模型下载根目录。可设置为任意本地路径,若 `ModelConfig` 中没有设置 `local_model_path`,则会将模型文件下载到这一环境变量指向的路径。若两者都未设置,则会将模型文件下载到 `./models`。 + +## `DIFFSYNTH_ATTENTION_IMPLEMENTATION` + +注意力机制实现的方式,可以设置为 `flash_attention_3`、`flash_attention_2`、`sage_attention`、`xformers`、`torch`。详见 [`./core/attention.md`](/docs/zh/API_Reference/core/attention.md). 
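+
+以下是一个最小示例(仅作示意,`sage_attention` 可替换为上文列出的任意一种已安装可用的实现),写法与本文开头的环境变量设置一致,需在 `import diffsynth` 之前完成:
+
+```python
+import os
+# 示例:切换到 SageAttention 实现;取值需为上文列出的选项之一
+os.environ["DIFFSYNTH_ATTENTION_IMPLEMENTATION"] = "sage_attention"
+import diffsynth
+```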
+ +## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE` + +硬盘直连中的 Buffer 大小,默认是 1B(1000000000),数值越大,占用内存越大,速度越快。 + +## `DIFFSYNTH_DOWNLOAD_SOURCE` + +远程模型下载源,可设置为 `modelscope` 或 `huggingface`,控制模型下载的来源,默认值为 `modelscope`。 diff --git a/docs/zh/Pipeline_Usage/Model_Inference.md b/docs/zh/Pipeline_Usage/Model_Inference.md new file mode 100644 index 0000000000000000000000000000000000000000..75a1ed8ee0dc433919fbc638024de775b768f8ad --- /dev/null +++ b/docs/zh/Pipeline_Usage/Model_Inference.md @@ -0,0 +1,105 @@ +# 模型推理 + +本文档以 Qwen-Image 模型为例,介绍如何使用 `DiffSynth-Studio` 进行模型推理。 + +## 加载模型 + +模型通过 `from_pretrained` 加载: + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +``` + +其中 `torch_dtype` 和 `device` 是计算精度和计算设备(不是模型的精度和设备)。`model_configs` 可通过多种方式配置模型路径,关于本项目内部是如何加载模型的,请参考 [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md)。 + +
+ +从远程下载模型并加载 + +> `DiffSynth-Studio` 默认从[魔搭社区](https://www.modelscope.cn/)下载并加载模型,需填写 `model_id` 和 `origin_file_pattern`,例如 +> +> ```python +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), +> ``` +> +> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 + +
+ +
+ +从本地文件路径加载模型 + +> 填写 `path`,例如 +> +> ```python +> ModelConfig(path="models/xxx.safetensors") +> ``` +> +> 对于从多个文件加载的模型,使用列表即可,例如 +> +> ```python +> ModelConfig(path=[ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", +> ]) +> ``` + +
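+>
+> 如果不想手动罗列分片文件,也可以借助 Python 标准库的 `glob` 构造路径列表(仅为一种写法示意,路径请以本地实际文件为准):
+>
+> ```python
+> from glob import glob
+>
+> ModelConfig(path=sorted(glob("models/Qwen/Qwen-Image/text_encoder/model-*.safetensors")))
+> ```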
+ +默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 + +```shell +import os +os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True" +import diffsynth +``` + +如需从 [HuggingFace](https://huggingface.co/) 下载模型,请将[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) 设置为 `huggingface`。 + +```shell +import os +os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface" +import diffsynth +``` + +## 启动推理 + +输入提示词,即可启动推理过程,生成一张图片。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +每个模型 `Pipeline` 的输入参数不同,请参考各模型的文档。 + +如果模型参数量太大,导致显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 diff --git a/docs/zh/Pipeline_Usage/Model_Training.md b/docs/zh/Pipeline_Usage/Model_Training.md new file mode 100644 index 0000000000000000000000000000000000000000..208246e396a992011a375b61989e85ef6aed683b --- /dev/null +++ b/docs/zh/Pipeline_Usage/Model_Training.md @@ -0,0 +1,245 @@ +# 模型训练 + +本文档介绍如何使用 `DiffSynth-Studio` 进行模型训练。 + +## 脚本参数 + +训练脚本通常包含以下参数: + +* 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 +* 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`,以 `,` 分隔。 + * `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。 +* 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。 + * `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。 + * `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。 +* 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。 +* LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 +* 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * 
`--gradient_accumulation_steps`: 梯度累积步数。 +* 图像宽高配置(适用于图像生成模型和视频生成模型) + * `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 + * `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。 + +部分模型的训练脚本还包含额外的参数,详见各模型的文档。 + +## 准备数据集 + +`DiffSynth-Studio` 采用通用数据集格式,数据集包含一系列数据文件(图像、视频等),以及标注元数据的文件,我们建议您这样组织数据集文件: + +``` +data/example_image_dataset/ +├── metadata.csv +├── image_1.jpg +└── image_2.jpg +``` + +其中 `image_1.jpg`、`image_2.jpg` 为训练用图像数据,`metadata.csv` 为元数据列表,例如 + +``` +image,prompt +image_1.jpg,"a dog" +image_2.jpg,"a cat" +``` + +我们构建了样例数据集,以方便您进行测试。了解通用数据集架构是如何实现的,请参考 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。 + +
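+
+元数据文件除 `csv` 外还支持 `json`、`jsonl` 格式(详见后文“训练注意事项”),字段名与 `csv` 表头保持一致即可。例如,上面的 `metadata.csv` 改写成 `jsonl` 后大致如下(具体格式以 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md) 为准):
+
+```
+{"image": "image_1.jpg", "prompt": "a dog"}
+{"image": "image_2.jpg", "prompt": "a cat"}
+```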
+ +样例图像数据集 + +> ```shell +> modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset +> ``` +> +> 适用于 Qwen-Image、FLUX 等图像生成模型的训练。 + +
+ +
+ +样例视频数据集 + +> ```shell +> modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset +> ``` +> +> 适用于 Wan 等视频生成模型的训练。 + +
+ +## 加载模型 + +类似于[推理时的模型加载](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型),我们支持多种方式配置模型路径,两种方式是可以混用的。 + +
+ +从远程下载模型并加载 + +> 如果在推理时我们通过以下设置加载模型 +> +> ```python +> model_configs=[ +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), +> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), +> ] +> ``` +> +> 那么在训练时,填入以下参数即可加载对应的模型。 +> +> ```shell +> --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" +> ``` +> +> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。 +> +> 默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。 + +
+ +
+ +从本地文件路径加载模型 + +> 如果从本地文件加载模型,例如推理时 +> +> ```python +> model_configs=[ +> ModelConfig([ +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors" +> ]), +> ModelConfig([ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +> ]), +> ModelConfig("models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors") +> ] +> ``` +> +> 那么训练时需设置为 +> +> ```shell +> --model_paths '[ +> [ +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors", +> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors" +> ], +> [ +> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", +> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +> ], +> "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors" +> ]' \ +> ``` +> +> 请注意,`--model_paths` 是 json 格式,其中不能出现多余的 `,`,否则无法被正常解析。 + +
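+
+编写训练命令前,可以先用 Python 标准库自查传给 `--model_paths` 的字符串是否为合法 JSON(仅为一个简单的校验示意,内容请替换为实际路径):
+
+```python
+import json
+
+# 将实际传给 --model_paths 的字符串粘贴到这里
+model_paths = '["models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors"]'
+json.loads(model_paths)  # 若抛出 json.JSONDecodeError,说明存在多余的逗号等格式问题
+```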
+ +## 设置可训练模块 + +训练框架支持任意模型的训练,以 Qwen-Image 为例,若全量训练其中的 DiT 模型,则需设置为 + +```shell +--trainable_models "dit" +``` + +若训练 DiT 模型的 LoRA,则需设置 + +```shell +--lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32 +``` + +我们希望给技术探索留下足够的发挥空间,因此框架支持同时训练任意多个模块,例如同时训练 text encoder、controlnet,以及 DiT 的 LoRA: + +```shell +--trainable_models "text_encoder,controlnet" --lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32 +``` + +此外,由于训练脚本中加载了多个模块(text encoder、dit、vae 等),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`。如果多个模块同时训练,则需开发者在训练完成后自行编写代码拆分模型文件中的 state dict。 + +## 启动训练程序 + +训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,训练命令按照如下格式编写: + +```shell +accelerate launch xxx/train.py \ + --xxx yyy \ + --xxxx yyyy +``` + +我们为每个模型编写了预置的训练脚本,详见各模型的文档。 + +默认情况下,`accelerate` 会按照 `~/.cache/huggingface/accelerate/default_config.yaml` 的配置进行训练,使用 `accelerate config` 可在终端交互式地配置,包括多 GPU 训练、[`DeepSpeed`](https://www.deepspeed.ai/) 等。 + +我们为部分模型提供了推荐的 `accelerate` 配置文件,可通过 `--config_file` 设置,例如 Qwen-Image 模型的全量训练: + +```shell +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters +``` + +## 训练注意事项 + +* 数据集的元数据除 `csv` 格式外,还支持 `json`、`jsonl` 格式,关于如何选择最佳的元数据格式,请参考[](/docs/zh/API_Reference/core/data.md#元数据) +* 通常训练效果与训练步数强相关,与 epoch 数量弱相关,因此我们更推荐使用参数 `--save_steps` 按训练步数间隔来保存模型文件。 +* 当数据量 * `dataset_repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。 +* 学习率 `--learning_rate` 在 LoRA 训练中建议设置为 `1e-4`,在全量训练中建议设置为 `1e-5`。 +* 训练框架不支持 batch size > 1,原因是复杂的,详见 [Q&A: 为什么训练框架不支持 batch size > 1?](/docs/zh/QA.md#为什么训练框架不支持-batch-size--1) +* 少数模型包含冗余参数,例如 Qwen-Image 的 DiT 部分最后一层的文本编码部分,在训练这些模型时,需设置 `--find_unused_parameters` 避免在多 GPU 训练中报错。出于对开源社区模型兼容性的考虑,我们不打算删除这些冗余参数。 +* Diffusion 模型的损失函数值与实际效果的关系不大,因此我们在训练过程中不会记录损失函数值。我们建议把 `--num_epochs` 设置为足够大的数值,边训边测,直至效果收敛后手动关闭训练程序。 +* `--use_gradient_checkpointing` 通常是开启的,除非 GPU 显存足够;`--use_gradient_checkpointing_offload` 则按需开启,详见 [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md)。 diff --git a/docs/zh/Pipeline_Usage/Setup.md b/docs/zh/Pipeline_Usage/Setup.md new file mode 100644 index 0000000000000000000000000000000000000000..715467c275fa70d7d47384b020e5861b290a002e --- /dev/null +++ b/docs/zh/Pipeline_Usage/Setup.md @@ -0,0 +1,21 @@ +# 安装依赖 + +从源码安装(推荐): + +``` +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . 
+``` + +从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装) + +``` +pip install diffsynth +``` + +如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档: + +* [torch](https://pytorch.org/get-started/locally/) +* [sentencepiece](https://github.com/google/sentencepiece) +* [cmake](https://cmake.org) diff --git a/docs/zh/Pipeline_Usage/VRAM_management.md b/docs/zh/Pipeline_Usage/VRAM_management.md new file mode 100644 index 0000000000000000000000000000000000000000..2235c12c26972534a3e81c46a7f59ee4824979af --- /dev/null +++ b/docs/zh/Pipeline_Usage/VRAM_management.md @@ -0,0 +1,206 @@ +# 显存管理 + +显存管理是 `DiffSynth-Studio` 的特色功能,能够让低显存的 GPU 能够运行参数量巨大的模型推理。本文档以 Qwen-Image 为例,介绍显存管理方案的使用。 + +## 基础推理 + +以下代码中没有启用任何显存管理,显存占用 56G,作为参考。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## CPU Offload + +由于模型 `Pipeline` 包括多个组件,这些组件并非同时调用的,因此我们可以在某些组件不需要参与计算时将其移至内存,减少显存占用,以下代码可以实现这一逻辑,显存占用 40G。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## FP8 量化 + +在 CPU Offload 的基础上,我们进一步启用 FP8 量化来减少显存需求,以下代码可以令模型参数以 FP8 精度存储在显存中,并在推理时临时转为 BF16 精度计算,显存占用 21G。但这种量化方案有微小的图像质量下降问题。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cuda", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", 
**vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +> Q: 为什么要在推理时临时转为 BF16 精度,而不是以 FP8 精度计算? +> +> A: FP8 的原生计算仅在 Hopper 架构的 GPU(例如 H20)支持,且计算误差很大,我们目前暂不开放 FP8 精度计算。目前的 FP8 量化仅能减少显存占用,不会提高计算速度。 + +## 动态显存管理 + +在 CPU Offload 中,我们对模型组件进行控制,事实上,我们支持做到 Layer 级别的 Offload,将一个模型拆分为多个 Layer,令一部分常驻显存,令一部分存储在内存中按需移至显存计算。这一功能需要模型开发者针对每个模型提供详细的显存管理方案,相关配置在 `diffsynth/configs/vram_management_module_maps.py` 中。 + +通过在 `Pipeline` 中增加 `vram_limit` 参数,框架可以自动感知设备的剩余显存并决定如何拆分模型到显存和内存中。`vram_limit` 越小,占用显存越少,速度越慢。 +* `vram_limit=None` 时,即默认状态,框架认为显存无限,动态显存管理是不启用的 +* `vram_limit=10` 时,框架会在显存占用超过 10G 之后限制模型,将超出的部分移至内存中存储。 +* `vram_limit=0` 时,框架会尽全力减少显存占用,所有模型参数都存储在内存中,仅在必要时移至显存计算 + +在显存不足以运行模型推理的情况下,框架会试图超出 `vram_limit` 的限制从而让模型推理运行下去,因此显存管理框架并不能总是保证占用的显存小于 `vram_limit`,我们建议将其设置为略小于实际可用显存的数值,例如 GPU 显存为 16G 时,设置为 `vram_limit=15.5`。`PyTorch` 中可用 `torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)` 获取 GPU 的显存。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## Disk Offload + +在更为极端的情况下,当内存也不足以存储整个模型时,Disk Offload 功能可以让模型参数惰性加载,即,模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时,我们建议使用高速的 SSD 硬盘。 + +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=10, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## 更多使用方式 + +`vram_config` 中的信息可自行填写,例如不开 FP8 量化的 Disk Offload: + +```python +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +``` + +具体地,显存管理模块会将模型的 Layer 分为以下四种状态: + +* Offload:短期内不调用这个模型,这个状态由 `Pipeline` 控制切换 +* Onload:接下来随时要调用这个模型,这个状态由 `Pipeline` 控制切换 +* Preparing:Onload 和 Computation 的中间状态,在显存允许的前提下的暂存状态,这个状态由显存管理机制控制切换,当且仅当【vram_limit 设置为无限制】或【vram_limit 已设置且有空余显存】时会进入这一状态 +* Computation:模型正在计算过程中,这个状态由显存管理机制控制切换,仅在 `forward` 中临时进入 + +如果你是模型开发者,希望自行控制某个模型的显存管理粒度,请参考[../Developer_Guide/Enabling_VRAM_management.md](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。 + +## 最佳实践 + +* 显存足够 -> 使用[基础推理](#基础推理) +* 显存不足 + * 内存足够 -> 使用[动态显存管理](#动态显存管理) + * 内存不足 -> 使用[Disk Offload](#disk-offload) diff --git a/docs/zh/QA.md b/docs/zh/QA.md new file mode 100644 index 0000000000000000000000000000000000000000..b1d55df801771c5789344733af27588f81b5c665 --- /dev/null +++ b/docs/zh/QA.md @@ -0,0 +1,28 @@ +# 常见问题 + +## 为什么训练框架不支持 batch size > 1? + +* **更大的 batch size 已无法实现显著加速**:由于 flash attention 等加速技术已经充分提高了 GPU 的利用率,因此更大的 batch size 只会带来更大的显存占用,无法带来显著加速。在 Stable Diffusion 1.5 这类小模型上的经验已不再适用于最新的大模型。 +* **更大的 batch size 可以用其他方案实现**:多 GPU 训练和 Gradient Accumulation 都可以在数学意义上等价地实现更大的 batch size。 +* **更大的 batch size 与框架的通用性设计相悖**:我们希望构建通用的训练框架,大量模型无法适配更大的 batch size,例如不同长度的文本编码、不同分辨率的图像等,都是无法合并为更大的 batch 的。 + +## 为什么不删除某些模型中的冗余参数? + +在部分模型中,模型存在冗余参数,例如 Qwen-Image 的 DiT 模型最后一层的文本部分,这部分参数不会参与任何计算,这是模型开发者留下的小 bug。直接将其设置为可训练时还会在多 GPU 训练中出现报错。 + +为了与开源社区中其他模型保持兼容性,我们决定保留这些参数。这些冗余参数在多 GPU 训练中可以通过 `--find_unused_parameters` 参数避免报错。 + +## 为什么 FP8 量化没有任何加速效果? + +原生 FP8 计算需要依赖 Hopper 架构的 GPU,同时在计算精度上有较大误差,目前仍然是不成熟的技术,因此本项目不支持原生 FP8 计算。 + +显存管理中的 FP8 计算是指将模型参数以 FP8 精度存储在内存或显存中,在需要计算时临时转换为其他精度,因此仅能减少显存占用,没有加速效果。 + +## 为什么训练框架不支持原生 FP8 精度训练? + +即使硬件条件允许,我们目前也没有任何支持原生 FP8 精度训练的规划。 + +* 目前原生 FP8 精度训练的主要挑战是梯度爆炸导致的精度溢出,为了保证训练的稳定性,需针对性地重新设计模型结构,然而目前还没有任何模型开发者愿意这么做。 +* 此外,使用原生 FP8 精度训练的模型,在推理时若没有 Hopper 架构 GPU,则只能以 BF16 精度进行计算,理论上其生成效果反而不如 FP8。 + +因此,原生 FP8 精度训练技术是极不成熟的,我们静观开源社区的技术发展。 diff --git a/docs/zh/README.md b/docs/zh/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b0b2310fc2cf45aaba8dd0b7851c1693b01990d4 --- /dev/null +++ b/docs/zh/README.md @@ -0,0 +1,88 @@ +# DiffSynth-Studio 文档 + +欢迎来到 Diffusion 模型的魔法世界!`DiffSynth-Studio` 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望构建一个通用的 Diffusion 模型框架,以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界! + +
+ +文档阅读导引 + +```mermaid +graph LR; + 我想要使用模型进行推理和训练-->sec1[Section 1: 上手使用]; + 我想要使用模型进行推理和训练-->sec2[Section 2: 模型详解]; + 我想要使用模型进行推理和训练-->sec3[Section 3: 训练框架]; + 我想要基于此框架进行二次开发-->sec3[Section 3: 训练框架]; + 我想要基于此框架进行二次开发-->sec4[Section 4: 模型接入]; + 我想要基于此框架进行二次开发-->sec5[Section 5: API 参考]; + 我想要基于本项目探索新的技术-->sec4[Section 4: 模型接入]; + 我想要基于本项目探索新的技术-->sec5[Section 5: API 参考]; + 我想要基于本项目探索新的技术-->sec6[Section 6: 学术导引]; + 我遇到了问题-->sec7[Section 7: 常见问题]; +``` + +
+ +## Section 1: 上手使用 + +本节介绍 `DiffSynth-Studio` 的基本使用方式,包括如何启用显存管理从而在极低显存的 GPU 上进行推理,以及如何训练任意基础模型、LoRA、ControlNet 等模型。 + +* [安装依赖](/docs/zh/Pipeline_Usage/Setup.md) +* [模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md) +* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md) +* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md) +* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md) + +## Section 2: 模型详解 + +本节介绍 `DiffSynth-Studio` 所支持的 Diffusion 模型,部分模型 Pipeline 具备可控生成、并行加速等特色功能。 + +* [FLUX.1](/docs/zh/Model_Details/FLUX.md) +* [Wan](/docs/zh/Model_Details/Wan.md) +* [Qwen-Image](/docs/zh/Model_Details/Qwen-Image.md) +* [FLUX.2](/docs/zh/Model_Details/FLUX2.md) +* [Z-Image](/docs/zh/Model_Details/Z-Image.md) + +## Section 3: 训练框架 + +本节介绍 `DiffSynth-Studio` 中训练框架的设计思路,帮助开发者理解 Diffusion 模型训练算法的原理。 + +* [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md) +* [标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md) +* [在训练中启用 FP8 精度](/docs/zh/Training/FP8_Precision.md) +* [端到端的蒸馏加速训练](/docs/zh/Training/Direct_Distill.md) +* [两阶段拆分训练](/docs/zh/Training/Split_Training.md) +* [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md) + +## Section 4: 模型接入 + +本节介绍如何将模型接入 `DiffSynth-Studio` 从而使用框架基础功能,帮助开发者为本项目提供新模型的支持,或进行私有化模型的推理和训练。 + +* [接入模型结构](/docs/zh/Developer_Guide/Integrating_Your_Model.md) +* [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) +* [接入细粒度显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md) +* [接入模型训练](/docs/zh/Developer_Guide/Training_Diffusion_Models.md) + +## Section 5: API 参考 + +本节介绍 `DiffSynth-Studio` 中的独立核心模块 `diffsynth.core`,介绍内部的功能是如何设计和运作的,开发者如有需要,可将其中的功能模块用于其他代码库的开发中。 + +* [`diffsynth.core.attention`](/docs/zh/API_Reference/core/attention.md): 注意力机制实现 +* [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md): 数据处理算子与通用数据集 +* [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md): 梯度检查点 +* [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md): 模型下载与加载 +* [`diffsynth.core.vram`](/docs/zh/API_Reference/core/vram.md): 显存管理 + +## Section 6: 学术导引 + +本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。 + +* 从零开始训练模型【coming soon】 +* 推理改进优化技术【coming soon】 +* 设计可控生成模型【coming soon】 +* 创建新的训练范式【coming soon】 + +## Section 7: 常见问题 + +本节总结了开发者常见的问题,如果你在使用和开发中遇到了问题,请参考本节内容,如果仍无法解决,请到 GitHub 上给我们提 issue。 + +* [常见问题](/docs/zh/QA.md) diff --git a/docs/zh/Training/Differential_LoRA.md b/docs/zh/Training/Differential_LoRA.md new file mode 100644 index 0000000000000000000000000000000000000000..2489ea0a6a595322ea192ed896699116006ec0a7 --- /dev/null +++ b/docs/zh/Training/Differential_LoRA.md @@ -0,0 +1,38 @@ +# 差分 LoRA 训练 + +差分 LoRA 训练是一种特殊的 LoRA 训练方式,旨在让模型学习图像之间的差异。 + +## 训练方案 + +我们未能找到差分 LoRA 训练最早由谁提出,这一技术已经在开源社区中流传甚久。 + +假设我们有两张内容相似的图像:图 1 和图 2。例如两张图中分别有一辆车,但图 1 中画面细节更少,图 2 中画面细节更多。在差分 LoRA 训练中,我们进行两步训练: + +* 以图 1 为训练数据,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 1 +* 以图 2 为训练数据,将 LoRA 1 融入基础模型后,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 2 + +在第一步训练中,由于训练数据仅有一张图,LoRA 模型很容易过拟合,因此训练完成后,LoRA 1 会让模型毫不犹豫地生成图 1,无论随机种子是什么。在第二步训练中,LoRA 模型再次过拟合,因此训练完成后,在 LoRA 1 和 LoRA 2 的共同作用下,模型会毫不犹豫地生成图 2。简言之: + +* LoRA 1 = 生成图 1 +* LoRA 1 + LoRA 2 = 生成图 2 + +此时丢弃 LoRA 1,只使用 LoRA 2,模型将会理解图 1 和图 2 的差异,使生成的内容倾向于“更不像图1,更像图 2”。 + +单一训练数据可以保证模型能够过拟合到训练数据上,但稳定性不足。为了提高稳定性,我们可以用多个图像对(image pairs)进行训练,并将训练出的 LoRA 2 进行平均,得到效果更稳定的 LoRA。 + +用这一训练方案,可以训练出一些功能奇特的 LoRA 模型。例如,使用丑陋的和漂亮的图像对,训练提升图像美感的 LoRA;使用细节少的和细节丰富的图像对,训练增加图像细节的 LoRA。 + +## 模型效果 + +我们用差分 LoRA 训练技术训练了几个美学提升 LoRA,可前往对应的模型页面查看生成效果。 + +* 
[DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1) +* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) + +## 在训练框架中使用差分 LoRA 训练 + +第一步的训练与普通 LoRA 训练没有任何差异,在第二步的训练命令中,通过 `--preset_lora_path` 参数填入第一步的 LoRA 模型文件路径,并将 `--preset_lora_model` 设置为与 `lora_base_model` 相同的参数,即可将 LoRA 1 加载到基础模型中。 + +## 框架设计思路 + +在训练框架中,`--preset_lora_path` 指向的模型在 `DiffusionTrainingModule` 的 `switch_pipe_to_training_mode` 中完成加载。 diff --git a/docs/zh/Training/Direct_Distill.md b/docs/zh/Training/Direct_Distill.md new file mode 100644 index 0000000000000000000000000000000000000000..946a767ceea2c15027e32a646d4dd6a67b644746 --- /dev/null +++ b/docs/zh/Training/Direct_Distill.md @@ -0,0 +1,97 @@ +# 端到端的蒸馏加速训练 + +## 蒸馏加速训练 + +Diffusion 模型的推理过程通常需要多步迭代,在提升生成效果的同时也让生成过程变得缓慢。通过蒸馏加速训练,可以减少生成清晰内容所需的步数。蒸馏加速训练技术的本质训练目标是让少量步数的生成效果与大量步数的生成效果对齐。 + +蒸馏加速训练的方法是多样的,例如 + +* 对抗式训练 ADD(Adversarial Diffusion Distillation) + * 论文:https://arxiv.org/abs/2311.17042 + * 模型:[stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo) +* 渐进式训练 Hyper-SD + * 论文:https://arxiv.org/abs/2404.13686 + * 模型:[ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD) + +## 直接蒸馏 + +在训练框架层面,支持这类蒸馏加速训练方案是极其困难的。在训练框架的设计中,我们需要保证训练方案满足以下条件: + +* 通用性:训练方案适用于大多数框架内支持的 Diffusion 模型,而非只能对某个特定模型生效,这是代码框架建设的基本要求。 +* 稳定性:训练方案需保证训练效果稳定,不需要人工进行精细的参数调整,ADD 中的对抗式训练则无法保证稳定性。 +* 简洁性:训练方案不会引入额外的复杂模块,根据奥卡姆剃刀([Occam's Razor](https://en.wikipedia.org/wiki/Occam%27s_razor))原理,复杂解决方案可能引入潜在风险,Hyper-SD 中的 Human Feedback Learning 让训练过程变得过于复杂。 + +因此,在 `DiffSynth-Studio` 的训练框架中,我们设计了一个端到端的蒸馏加速训练方案,我们称为直接蒸馏(Direct Distill),其训练过程的伪代码如下: + +``` +seed = xxx +with torch.no_grad(): + image_1 = pipe(prompt, steps=50, seed=seed, cfg=4) +image_2 = pipe(prompt, steps=4, seed=seed, cfg=1) +loss = torch.nn.functional.mse_loss(image_1, image_2) +``` + +是的,非常端到端的训练方案,稍加训练就可以有立竿见影的效果。 + +## 直接蒸馏训练的模型 + +我们用这个方案基于 Qwen-Image 训练了两个模型: + +* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): 全量蒸馏训练 +* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA 蒸馏训练 + +点击模型链接即可前往模型页面查看模型效果。 + +## 在训练框架中使用蒸馏加速训练 + +首先,需要生成训练数据,请参考[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md)部分编写推理代码,以足够多的推理步数生成训练数据。 + +以 Qwen-Image 为例,以下代码可以生成一张图片: + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +然后,我们把必要的信息编写成[元数据文件](/docs/zh/API_Reference/core/data.md#元数据): + +```csv +image,prompt,seed,rand_device,num_inference_steps,cfg_scale +distill_qwen/image.jpg,"精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",0,cpu,4,1 +``` + +这个样例数据集可以直接下载: + +```shell +modelscope download --dataset DiffSynth-Studio/example_image_dataset 
--local_dir ./data/example_image_dataset +``` + +然后开始 LoRA 蒸馏加速训练: + +```shell +bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh +``` + +请注意,在[训练脚本参数](/docs/zh/Pipeline_Usage/Model_Training.md#脚本参数)中,数据集的图像分辨率设置要避免触发缩放处理。当设定 `--height` 和 `--width` 以启用固定分辨率时,所有训练数据必须是以完全一致的宽高生成的;当设定 `--max_pixels` 以启用动态分辨率时,`--max_pixels` 的数值必须大于或等于任一训练图像的像素面积。 + +## 训练框架设计思路 + +直接蒸馏与[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)相比,仅训练的损失函数不同,直接蒸馏的损失函数是 `diffsynth.diffusion.loss` 中的 `DirectDistillLoss`。 + +## 未来工作 + +直接蒸馏是通用性很强的加速方案,但未必是效果最好的方案,所以我们暂未把这一技术以论文的形式发布。我们希望把这个问题交给学术界和开源社区共同解决,期待开发者能够给出更完善的通用训练方案。 diff --git a/docs/zh/Training/FP8_Precision.md b/docs/zh/Training/FP8_Precision.md new file mode 100644 index 0000000000000000000000000000000000000000..a1f428aad3d6b720174e325ed467a629823ad617 --- /dev/null +++ b/docs/zh/Training/FP8_Precision.md @@ -0,0 +1,20 @@ +# 在训练中启用 FP8 精度 + +尽管 `DiffSynth-Studio` 在模型推理中支持[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),但其中的大部分减少显存占用的技术不适合用于训练中,Offload 会导致极为缓慢的训练过程。 + +FP8 精度是唯一可在训练过程中启用的显存管理策略,但本框架目前不支持原生 FP8 精度训练,原因详见 [Q&A: 为什么训练框架不支持原生 FP8 精度训练?](/docs/zh/QA.md#为什么训练框架不支持原生-fp8-精度训练),仅支持将参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)以 FP8 精度进行存储。 + +## 启用 FP8 + +在我们提供的训练脚本中,通过参数 `--fp8_models` 即可快速设置以 FP8 精度存储的模型。以 Qwen-Image 的 LoRA 训练为例,我们提供了启用 FP8 训练的脚本,位于 [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh)。训练完成后,可通过脚本 [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py) 验证训练效果。 + +请注意,这种 FP8 显存管理策略不支持梯度更新,当某个模型被设置为可训练时,不能为这个模型开启 FP8 精度,支持开启 FP8 的模型包括两类: + +* 参数不可训练,例如 VAE 模型 +* 梯度不更新其参数,例如 LoRA 训练中的 DiT 模型 + +经实验验证,开启 FP8 后的 LoRA 训练效果没有明显的图像质量下降,但理论上误差是确实存在的,如果在使用本功能时遇到训练效果不如 BF16 精度训练的问题,请通过 GitHub issue 给我们提供反馈。 + +## 训练框架设计思路 + +训练框架完全沿用推理的显存管理,在训练中仅通过 `DiffusionTrainingModule` 中的 `parse_model_configs` 解析显存管理配置。 diff --git a/docs/zh/Training/Split_Training.md b/docs/zh/Training/Split_Training.md new file mode 100644 index 0000000000000000000000000000000000000000..f98d56e087c4bd8f7d07ecf64aeda2e10f0e4aaf --- /dev/null +++ b/docs/zh/Training/Split_Training.md @@ -0,0 +1,97 @@ +# 两阶段拆分训练 + +本文档介绍拆分训练,能够自动将训练过程拆分为两阶段进行,减少显存占用,同时加快训练速度。 + +(拆分训练是实验性特性,尚未进行大规模验证,如果在使用中出现问题,请在 GitHub 上给我们提 issue。) + +## 拆分训练 + +在大部分模型的训练过程中,大量计算发生在“前处理”中,即“与去噪模型无关的计算”,包括 VAE 编码、文本编码等。当对应的模型参数固定时,这部分计算的结果是重复的,在多个 epoch 中每个数据样本的计算结果完全相同,因此我们提供了“拆分训练”功能,该功能可以自动分析并拆分训练过程。 + +对于普通文生图模型的标准监督训练,拆分过程是非常简单的,只需要把所有 [`Pipeline Units`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段,将计算结果存储到硬盘中,然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传,情况就变得极其复杂,为此,我们引入了一个计算图拆分算法用于分析如何拆分计算。 + +## 计算图拆分算法 + +> (我们会在后续的文档更新中补充计算图拆分算法的详细细节) + +## 使用拆分训练 + +拆分训练已支持[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)和[直接蒸馏训练](/docs/zh/Training/Direct_Distill.md),在训练命令中通过 `--task` 参数控制,以 Qwen-Image 模型的 LoRA 训练为例,拆分前的训练命令为: + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters +``` + +拆分后,在第一阶段中,做如下修改: + +* 将 `--dataset_repeat` 改为 1,避免重复计算 +* 将 `--output_path` 改为第一阶段计算结果保存的路径 +* 添加额外参数 `--task "sft:data_process"` +* 删除 `--model_id_with_origin_paths` 中的 DiT 模型 + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:data_process" +``` + +在第二阶段,做如下修改: + +* 将 `--dataset_base_path` 改为第一阶段的 `--output_path` +* 删除 `--dataset_metadata_path` +* 添加额外参数 `--task "sft:train"` +* 删除 `--model_id_with_origin_paths` 中的 Text Encoder 和 VAE 模型 + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:train" +``` + +我们提供了样例训练脚本和验证脚本,位于 `examples/qwen_image/model_training/special/split_training`。 + +## 训练框架设计思路 + +训练框架通过 `DiffusionTrainingModule` 的 `split_pipeline_units` 方法拆分 `Pipeline` 中的计算单元。 diff --git a/docs/zh/Training/Supervised_Fine_Tuning.md b/docs/zh/Training/Supervised_Fine_Tuning.md new file mode 100644 index 0000000000000000000000000000000000000000..f2f8aa354c00d8e2a5f04bab66ebe94a4bedb0fd --- /dev/null +++ b/docs/zh/Training/Supervised_Fine_Tuning.md @@ -0,0 +1,129 @@ +# 标准监督训练 + +在理解 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理,帮助开发者编写新的训练代码,如需使用我们提供的默认训练功能,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)。 + +回顾前文中的模型训练伪代码,当我们实际编写代码时,情况会变得极为复杂。部分模型需要输入额外的引导条件并进行预处理,例如 ControlNet;部分模型需要与去噪模型进行交叉式的计算,例如 VACE;部分模型因显存需求过大,需要开启 Gradient Checkpointing,例如 Qwen-Image 的 DiT。 + +为了实现严格的推理和训练一致性,我们对 `Pipeline` 等组件进行了抽象封装,在训练过程中大量复用推理代码。请参考[接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 了解 `Pipeline` 组件的设计。接下来我们介绍训练框架如何利用 `Pipeline` 组件构建训练算法。 + +## 框架设计思路 + +训练模块在 `Pipeline` 上层进行封装,继承 `diffsynth.diffusion.training_module` 中的 `DiffusionTrainingModule`,我们需为训练模块提供必要的 `__init__` 和 `forward` 方法。我们以 Qwen-Image 的 LoRA 训练为例,在 `examples/qwen_image/model_training/special/simple/train.py` 中提供了仅包含基础训练功能的简易脚本,帮助开发者理解训练模块的设计思路。 + +```python +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__(self, device): + # Initialize models here. + pass + + def forward(self, data): + # Compute loss here. + return loss +``` + +### `__init__` + +在 `__init__` 中需进行模型的初始化,先加载模型,然后将其切换到训练模式。 + +```python + def __init__(self, device): + super().__init__() + # Load the pipeline + self.pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + ) + # Switch to training mode + self.switch_pipe_to_training_mode( + self.pipe, + lora_base_model="dit", + lora_target_modules="to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj", + lora_rank=32, + ) +``` + +加载模型的逻辑与推理时基本一致,支持从远程和本地路径加载模型,详见[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md),但请注意不要启用[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。 + +`switch_pipe_to_training_mode` 可以将模型切换到训练模式,详见 `switch_pipe_to_training_mode`。 + +### `forward` + +在 `forward` 中需计算损失函数值,先进行前处理,然后经过 `Pipeline` 的 [`model_fn`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#model_fn) 计算损失函数。 + +```python + def forward(self, data): + # Preprocess + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. 
+ "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": True, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + # Loss + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss +``` + +前处理过程与推理阶段一致,开发者只需假定在使用 `Pipeline` 进行推理,将输入参数填入即可。 + +损失函数的计算沿用 `diffsynth.diffusion.loss` 中的 `FlowMatchSFTLoss`。 + +### 开始训练 + +训练框架还需其他模块,包括: + +* accelerator: `accelerate` 提供的训练启动器,详见 [`accelerate`](https://huggingface.co/docs/accelerate/index) +* dataset: 通用数据集,详见 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md) +* model_logger: 模型记录器,详见 `diffsynth.diffusion.logger` + +```python +if __name__ == "__main__": + accelerator = accelerate.Accelerator( + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)], + ) + dataset = UnifiedDataset( + base_path="data/example_image_dataset", + metadata_path="data/example_image_dataset/metadata.csv", + repeat=50, + data_file_keys="image", + main_data_operator=UnifiedDataset.default_image_operator( + base_path="data/example_image_dataset", + height=512, + width=512, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule(accelerator.device) + model_logger = ModelLogger( + output_path="models/toy_model", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=1e-5, num_epochs=1, + ) +``` + +将以上所有代码组装,得到 `examples/qwen_image/model_training/special/simple/train.py`。使用以下命令即可启动训练: + +``` +accelerate launch examples/qwen_image/model_training/special/simple/train.py +``` diff --git a/docs/zh/Training/Understanding_Diffusion_models.md b/docs/zh/Training/Understanding_Diffusion_models.md new file mode 100644 index 0000000000000000000000000000000000000000..576edc9c6a5acb62b2eb4605c191d2d0b747b6fd --- /dev/null +++ b/docs/zh/Training/Understanding_Diffusion_models.md @@ -0,0 +1,143 @@ +# Diffusion 模型基本原理 + +本文介绍 Diffusion 模型的基本原理,帮助你理解训练框架是如何构建的。为了让读者更轻松地理解这些复杂的数学理论,我们重构了 Diffusion 模型的理论框架,抛弃了复杂的随机微分方程,用一种更简洁易懂的形式进行介绍。 + +## 引言 + +Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。 + +(图) + +这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题: + +* 每一步的噪声含量是如何定义的? +* 迭代去噪的计算是如何进行的? +* 如何训练这样的 Diffusion 模型? +* 现代 Diffusion 模型的架构是什么样的? +* 本项目如何封装和实现模型训练? + +## 每一步的噪声含量是如何定义的? + +在 Diffusion 模型的理论体系中,噪声的含量是由一系列参数 $\sigma_T$、$\sigma_{T-1}$、$\sigma_{T-2}$、$\cdots$、$\sigma_0$ 决定的。其中 + +* $\sigma_T=1$,对应的 $x_T$ 为纯粹的高斯噪声 +* $\sigma_T>\sigma_{T-1}>\sigma_{T-2}>\cdots>x_0$,在迭代过程中噪声含量逐渐减小 +* $\sigma_0=0$,对应的 $x_0$ 为不含任何噪声的数据样本 + +至于中间 $\sigma_{T-1}$、$\sigma_{T-2}$、$\cdots$、$\sigma_1$ 的数值,则不是固定的,满足递减的条件即可。 + +那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。 + +(图) + +## 迭代去噪的计算是如何进行的? 
+ +在理解迭代去噪的计算前,我们要先搞清楚,去噪模型的输入和输出是什么。我们把模型抽象成一个符号 $\hat \epsilon$,它的输入通常包含三部分 + +* 时间步 $t$,模型需要理解当前处于去噪过程的哪个阶段 +* 含噪声的数据样本 $x_t$,模型需要理解要对什么数据进行去噪 +* 引导条件 $c$,模型需要理解要通过去噪生成什么样的数据样本 + +其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。 + +(图) + +而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。 + +接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$: +$$ +\begin{aligned} +x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\ +&\approx x_t + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\ +&=(1-\sigma_t)x_0+\sigma_t x_T + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\ +&=(1-\sigma_{t-1})x_0+\sigma_{t-1}x_T +\end{aligned} +$$ +完美!与时间步 $t-1$ 时的噪声含量定义完美契合。 + +> (这部分可能有点难懂,请不必担心,首次阅读本文时建议跳过这部分,不影响后文的阅读。) +> +> 完成了这段有点复杂的公式推导后,我们思考一个问题,为什么模型的输出要近似地等于 $x_T-x_0$ 呢?可以设定成其他值吗? +> +> 实际上,Diffusion 模型依赖两个定义形成完备的理论。在以上的公式中,我们可以提炼出这两个定义,并导出迭代公式: +> +> * 数据定义:$x_t=(1-\sigma_t)x_0+\sigma_t x_T$ +> * 模型定义:$\hat \epsilon(x_t,c,t)=x_T-x_0$ +> * 导出迭代公式:$x_{t-1}=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)$ +> +> 这三个数学公式是完备的,例如在刚才的推导中,我们把数据定义和模型定义代入迭代公式,可以得到与数据定义吻合的 $x_{t-1}$。 +> +> 这是基于 Flow Matching 理论构建的两个定义,但 Diffusion 模型也可用其他的两个定义来实现,例如早期基于 DDPM(Denoising Diffusion Probabilistic Models)的模型,其两个定义及导出的迭代公式为: +> +> * 数据定义:$x_t=\sqrt{\alpha_t}x_0+\sqrt{1-\alpha_t}x_T$ +> * 模型定义:$\hat \epsilon(x_t,c,t)=x_T$ +> * 导出迭代公式:$x_{t-1}=\sqrt{\alpha_{t-1}}\left(\frac{x_t-\sqrt{1-\alpha_t}\hat \epsilon(x_t,c,t)}{\sqrt{\sigma_t}}\right)+\sqrt{1-\alpha_{t-1}}\hat \epsilon(x_t,c,t)$ +> +> 更一般地,我们用矩阵描述迭代公式的导出过程,对于任意数据定义和模型定义,有: +> +> * 数据定义:$x_t=C_T(x_0,x_T)^T$ +> * 模型定义:$\hat \epsilon(x_t,c,t)=C_T^{[\epsilon]}(x_0,x_T)^T$ +> * 导出迭代公式:$x_{t-1}=C_{t-1}(C_t,C_t^{[\epsilon]})^{-T}(x_t,\hat \epsilon(x_t,c,t))^T$ +> +> 其中,$C_t$、$C_t^{[\epsilon]}$ 是 $1\times 2$ 的系数矩阵,不难发现,在构造两个定义时,需保证矩阵 $(C_t,C_t^{[\epsilon]})^T$ 是可逆的。 +> +> 尽管 Flow Matching 与 DDPM 已被大量预训练模型广泛验证过,但这并不代表这是最优的方案,我们鼓励开发者设计新的 Diffusion 模型理论实现更好的训练效果。 + +## 如何训练这样的 Diffusion 模型? + +搞清楚迭代去噪的过程之后,接下来我们考虑如何训练这样的 Diffusion 模型。 + +训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。 + +(图) + +以下是训练过程的伪代码 + +> 从数据集获取数据样本 $x_0$ 和引导条件 $c$ +> +> 随机采样时间步 $t\in(0,T]$ +> +> 随机采样高斯噪声 $x_T\in \mathcal N(O,I)$ +> +> $x_t=(1-\sigma_t)x_0+\sigma_t x_T$ +> +> $\hat \epsilon(x_t,c,t)$ +> +> 损失函数 $\mathcal L=||\hat \epsilon(x_t,c,t)-(x_T-x_0)||_2^2$ +> +> 梯度回传并更新模型参数 + +## 现代 Diffusion 模型的架构是什么样的? + +从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。 + +(图) + +### 数据编解码器 + +在前文中,我们一直将 $x_0$ 称为“数据样本”,而不是图像或视频,这是因为现代 Diffusion 模型通常不会直接在图像或视频上进行处理,而是用编码器(Encoder)-解码器(Decoder)架构的模型,通常是 VAE(Variational Auto-Encoders)模型,将图像或视频编码为 Embedding 张量,得到 $x_0$。 + +数据经过编码器编码后,再经过解码器解码,重建后的内容与原来近似地一致,会有少量误差。那么,为什么要在编码后的 Embedding 张量上处理,而不是在图像或视频上直接处理呢?主要原因有亮点: + +* 编码的同时对数据进行了压缩,编码后处理的计算量更小。 +* 编码后的数据分布与高斯分布更相似,更容易用去噪模型对数据进行建模。 + +在生成过程中,编码器部分不参与计算,迭代完成后,用解码器部分解码 $x_0$ 即可得到清晰的图像或视频。在训练过程中,解码器部分不参与计算,仅编码器用于计算 $x_0$。 + +### 引导条件编码器 + +用户输入的引导条件 $c$ 可能是复杂多样的,需要由专门的编码器模型将其处理成 Embedding 张量。按照引导条件的类型,我们把引导条件编码器分为以下几类: + +* 文本类型,例如 CLIP、Qwen-VL +* 图像类型,例如 ControlNet、IP-Adapter +* 视频类型,例如 VAE + +> 前文中的模型 $\hat \epsilon$ 指代此处的所有引导条件编码器和去噪模型这一整体,我们把引导条件编码器单独拆分列出,因为这类模型在 Diffusion 训练中通常是冻结的,且输出值与时间步 $t$ 无关,因此引导条件编码器的计算可以离线进行。 + +### 去噪模型 + +去噪模型是 Diffusion 模型真正的本体,其模型结构多种多样,例如 UNet、DiT,模型开发者可在此结构上自由发挥。 + +## 本项目如何封装和实现模型训练? 
+ +请阅读下一文档:[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md) diff --git a/examples/Comp-Attn.pdf b/examples/Comp-Attn.pdf new file mode 100644 index 0000000000000000000000000000000000000000..21a5770abaaacaa92eb2a4a955db410bf2d813cb --- /dev/null +++ b/examples/Comp-Attn.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f5c832ca31c7f21f9ea4c2826605dd34e3750fd9ed8da06b89a78fb8b11d69b +size 6067768 diff --git a/examples/InstanceV.pdf b/examples/InstanceV.pdf new file mode 100644 index 0000000000000000000000000000000000000000..117370687f13a06022c390a0dd0adbfcdc5d2276 --- /dev/null +++ b/examples/InstanceV.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcbdc6792ce5c52bbbd3178e60dab38e80057ec6fa3c65fcdbdd196149581293 +size 16723166 diff --git a/examples/comp_attn_trajectory.png b/examples/comp_attn_trajectory.png new file mode 100644 index 0000000000000000000000000000000000000000..c34ebf660052310bdf7cfcb1583c3d8f3b276107 Binary files /dev/null and b/examples/comp_attn_trajectory.png differ diff --git a/examples/dev_tools/fix_path.py b/examples/dev_tools/fix_path.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd1733a6bc110b5cf31f75229c6d7f508ced8ae --- /dev/null +++ b/examples/dev_tools/fix_path.py @@ -0,0 +1,43 @@ +import re, os + + +def read_file(path): + with open(path, "r", encoding="utf-8-sig") as f: + context = f.read() + return context + +def get_files(files, path): + if os.path.isdir(path): + for folder in os.listdir(path): + get_files(files, os.path.join(path, folder)) + elif path.endswith(".md"): + files.append(path) + +def fix_path(doc_root_path): + files = [] + get_files(files, doc_root_path) + file_map = {} + for file in files: + name = file.split("/")[-1] + file_map[name] = "/" + file + + pattern = re.compile(r'\]\([^)]*\.md') + for file in files: + context = read_file(file) + matches = pattern.findall(context) + + edited = False + for match in matches: + target = "](" + file_map[match.split("/")[-1].replace("](", "")] + context = context.replace(match, target) + if target != match: + print(match, target) + edited = True + print(file, match, target) + + if edited: + with open(file, "w", encoding="utf-8") as f: + f.write(context) + +fix_path("doc/zh") +fix_path("doc/en") \ No newline at end of file diff --git a/examples/dev_tools/unit_test.py b/examples/dev_tools/unit_test.py new file mode 100644 index 0000000000000000000000000000000000000000..364af476a5ae4312d6bdc25541fe38a8eaf71e2c --- /dev/null +++ b/examples/dev_tools/unit_test.py @@ -0,0 +1,114 @@ +import os, shutil, multiprocessing, time +NUM_GPUS = 7 + + +def script_is_processed(output_path, script): + return os.path.exists(os.path.join(output_path, script)) and "log.txt" in os.listdir(os.path.join(output_path, script)) + + +def filter_unprocessed_tasks(script_path): + tasks = [] + output_path = os.path.join("data", script_path) + for script in sorted(os.listdir(script_path)): + if not script.endswith(".sh") and not script.endswith(".py"): + continue + if script_is_processed(output_path, script): + continue + tasks.append(script) + return tasks + + +def run_inference(script_path): + tasks = filter_unprocessed_tasks(script_path) + output_path = os.path.join("data", script_path) + for script in tasks: + source_path = os.path.join(script_path, script) + target_path = os.path.join(output_path, script) + os.makedirs(target_path, exist_ok=True) + cmd = f"python {source_path} > {target_path}/log.txt 2>&1" + print(cmd, flush=True) + os.system(cmd) + 
for file_name in os.listdir("./"): + if file_name.endswith(".jpg") or file_name.endswith(".png") or file_name.endswith(".mp4"): + shutil.move(file_name, os.path.join(target_path, file_name)) + + +def run_tasks_on_single_GPU(script_path, tasks, gpu_id, num_gpu): + output_path = os.path.join("data", script_path) + for script_id, script in enumerate(tasks): + if script_id % num_gpu != gpu_id: + continue + source_path = os.path.join(script_path, script) + target_path = os.path.join(output_path, script) + os.makedirs(target_path, exist_ok=True) + if script.endswith(".sh"): + cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} bash {source_path} > {target_path}/log.txt 2>&1" + elif script.endswith(".py"): + cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python {source_path} > {target_path}/log.txt 2>&1" + print(cmd, flush=True) + os.system(cmd) + + +def run_train_multi_GPU(script_path): + tasks = filter_unprocessed_tasks(script_path) + output_path = os.path.join("data", script_path) + for script in tasks: + source_path = os.path.join(script_path, script) + target_path = os.path.join(output_path, script) + os.makedirs(target_path, exist_ok=True) + cmd = f"bash {source_path} > {target_path}/log.txt 2>&1" + print(cmd, flush=True) + os.system(cmd) + time.sleep(1) + + +def run_train_single_GPU(script_path): + tasks = filter_unprocessed_tasks(script_path) + processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, NUM_GPUS)) for i in range(NUM_GPUS)] + for p in processes: + p.start() + for p in processes: + p.join() + + +def move_files(prefix, target_folder): + os.makedirs(target_folder, exist_ok=True) + os.system(f"cp -r {prefix}* {target_folder}") + os.system(f"rm -rf {prefix}*") + + +def test_qwen_image(): + run_inference("examples/qwen_image/model_inference") + run_inference("examples/qwen_image/model_inference_low_vram") + run_train_multi_GPU("examples/qwen_image/model_training/full") + run_inference("examples/qwen_image/model_training/validate_full") + run_train_single_GPU("examples/qwen_image/model_training/lora") + run_inference("examples/qwen_image/model_training/validate_lora") + + +def test_wan(): + run_train_single_GPU("examples/wanvideo/model_inference") + move_files("video_", "data/output/model_inference") + run_train_single_GPU("examples/wanvideo/model_inference_low_vram") + move_files("video_", "data/output/model_inference_low_vram") + run_train_multi_GPU("examples/wanvideo/model_training/full") + run_train_single_GPU("examples/wanvideo/model_training/validate_full") + move_files("video_", "data/output/validate_full") + run_train_single_GPU("examples/wanvideo/model_training/lora") + run_train_single_GPU("examples/wanvideo/model_training/validate_lora") + move_files("video_", "data/output/validate_lora") + + +def test_flux(): + run_inference("examples/flux/model_inference") + run_inference("examples/flux/model_inference_low_vram") + run_train_multi_GPU("examples/flux/model_training/full") + run_inference("examples/flux/model_training/validate_full") + run_train_single_GPU("examples/flux/model_training/lora") + run_inference("examples/flux/model_training/validate_lora") + + +if __name__ == "__main__": + test_qwen_image() + test_flux() + test_wan() diff --git a/examples/flux/model_inference/FLEX.2-preview.py b/examples/flux/model_inference/FLEX.2-preview.py new file mode 100644 index 0000000000000000000000000000000000000000..efc8e9143c299b3a408967dbde898f78f959a95a --- /dev/null +++ b/examples/flux/model_inference/FLEX.2-preview.py @@ -0,0 +1,50 @@ +import torch +from 
diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth.utils.controlnet import Annotator +import numpy as np +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 +) +image.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[200:400, 400:700] = 255 +mask = Image.fromarray(mask) +mask.save("image_mask.jpg") + +inpaint_image = image + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask, + seed=4 +) +image.save("image_2.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save("image_3.jpg") diff --git a/examples/flux/model_inference/FLUX.1-Kontext-dev.py b/examples/flux/model_inference/FLUX.1-Kontext-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..e7aae1bb2410d397585082607697aa5a778a64f0 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-Kontext-dev.py @@ -0,0 +1,54 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image_1 = pipe( + prompt="a beautiful Asian long-haired female college student.", + embedded_guidance=2.5, + seed=1, +) +image_1.save("image_1.jpg") + +image_2 = pipe( + prompt="transform the style to anime style.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=2, +) +image_2.save("image_2.jpg") + +image_3 = pipe( + prompt="let her smile.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=3, +) +image_3.save("image_3.jpg") + +image_4 = pipe( + prompt="let the girl play basketball.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=4, +) +image_4.save("image_4.jpg") + +image_5 = pipe( + prompt="move the girl to a park, let her sit on a chair.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=5, +) +image_5.save("image_5.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-Krea-dev.py b/examples/flux/model_inference/FLUX.1-Krea-dev.py 
new file mode 100644 index 0000000000000000000000000000000000000000..978a26a90d460c1723b4865ff5e36e5d98cddd17 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-Krea-dev.py @@ -0,0 +1,27 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +prompt = "An beautiful woman is riding a bicycle in a park, wearing a red dress" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + +image = pipe(prompt=prompt, seed=0, embedded_guidance=4.5) +image.save("flux_krea.jpg") + +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + seed=0, cfg_scale=2, num_inference_steps=50, + embedded_guidance=4.5 +) +image.save("flux_krea_cfg.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py b/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py new file mode 100644 index 0000000000000000000000000000000000000000..b35cce8768f3aae7dbc7b7a6ddf7fd50eb029d6d --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py @@ -0,0 +1,19 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + ], +) + +for i in [0.1, 0.3, 0.5, 0.7, 0.9]: + image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i]) + image.save(f"value_control_{i}.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0d1f3a8259a61f3cd6cb1b5fbcf256f9f43c24 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py @@ -0,0 +1,37 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +import numpy as np +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) + +image_1 = pipe( + prompt="a cat sitting on a chair", + height=1024, width=1024, + seed=8, rand_device="cuda", +) +image_1.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[100:350, 350: -300] = 255 +mask = Image.fromarray(mask) +mask.save("mask.jpg") + +image_2 = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)], + height=1024, width=1024, + seed=9, rand_device="cuda", +) +image_2.save("image_2.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa10aac083ad338028aeb30a99b63efdfa656d9 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py @@ -0,0 +1,40 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth.utils.controlnet import Annotator +from modelscope import snapshot_download + + + +snapshot_download("sd_lora/Annotators", allow_file_pattern="dpt_hybrid-midas-501f0c75.pt", local_dir="models/Annotators") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) + +image_1 = pipe( + prompt="a beautiful Asian girl, full body, red dress, summer", + height=1024, width=1024, + seed=6, rand_device="cuda", +) +image_1.save("image_1.jpg") + +image_canny = Annotator("canny")(image_1) +image_depth = Annotator("depth")(image_1) + +image_2 = pipe( + prompt="a beautiful Asian girl, full body, red dress, winter", + controlnet_inputs=[ + ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"), + ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"), + ], + height=1024, width=1024, + seed=7, rand_device="cuda", +) +image_2.save("image_2.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c288df1668b8731e5a80f367347819408f71bf --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py @@ -0,0 +1,33 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) + +image_1 = pipe( + prompt="a photo of a cat, highly detailed", + height=768, width=768, + seed=0, rand_device="cuda", +) +image_1.save("image_1.jpg") + +image_1 = image_1.resize((2048, 2048)) +image_2 = pipe( + prompt="a photo of a cat, highly detailed", + controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)], + input_image=image_1, + denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=1, rand_device="cuda", +) +image_2.save("image_2.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-EliGen.py b/examples/flux/model_inference/FLUX.1-dev-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc4d2e1bbe3b5591666f15e00d9098d562f4b6b --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-EliGen.py @@ -0,0 +1,133 @@ +import random +import torch +from PIL import Image, ImageDraw, ImageFont +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in 
range(len(entity_prompts))] + negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="DiffSynth-Studio/Eligen", origin_file_pattern="model_bf16.safetensors"), alpha=1) + +# example 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." 
+entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) + +# example 3 +global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning," +entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"] +example(pipe, [27], 3, global_prompt, entity_prompts) + +# example 4 +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +example(pipe, [21], 4, global_prompt, entity_prompts) + +# example 5 +global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere." +entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"] +example(pipe, [0], 5, global_prompt, entity_prompts) + +# example 6 +global_prompt = "Snow White and the 6 Dwarfs." +entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"] +example(pipe, [8], 6, global_prompt, entity_prompts) + +# example 7, same prompt with different seeds +seeds = range(5, 9) +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts) diff --git a/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..1479e1da4bf59a1ab0b669147cb3488bcff3c66b --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), + ], +) + +origin_prompt = "a rabbit in a garden, colorful flowers" +image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42) +image.save("style image.jpg") + +image = 
pipe(prompt="A piggy", height=1280, width=960, seed=42, + ipadapter_images=[image], ipadapter_scale=0.7) +image.save("A piggy.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py new file mode 100644 index 0000000000000000000000000000000000000000..4491ccb48e8642b4b1afe8a9bce74df0059b9b79 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py @@ -0,0 +1,61 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from modelscope import snapshot_download +from PIL import Image +import numpy as np + +# This model has additional requirements. +# Please install the following packages. +# pip install facexlib insightface onnxruntime +snapshot_download( + "ByteDance/InfiniteYou", + allow_file_pattern="supports/insightface/models/antelopev2/*", + local_dir="models/ByteDance/InfiniteYou", +) +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"), + ], +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/infiniteyou/*", +) + +height, width = 1024, 1024 +controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8)) +controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")] + +prompt = "A man, portrait, cinematic" +id_image = "data/examples/infiniteyou/man.jpg" +id_image = Image.open(id_image).convert('RGB') +image = pipe( + prompt=prompt, seed=1, + infinityou_id_image=id_image, infinityou_guidance=1.0, + controlnet_inputs=controlnet_inputs, + num_inference_steps=50, embedded_guidance=3.5, + height=height, width=width, +) +image.save("man.jpg") + +prompt = "A woman, portrait, cinematic" +id_image = "data/examples/infiniteyou/woman.jpg" +id_image = Image.open(id_image).convert('RGB') +image = pipe( + prompt=prompt, seed=1, + infinityou_id_image=id_image, infinityou_guidance=1.0, + controlnet_inputs=controlnet_inputs, + num_inference_steps=50, embedded_guidance=3.5, + height=height, width=width, +) +image.save("woman.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75f1bc80444aebd54290b46be9e4d9eb3c8d2e8d --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py @@ -0,0 +1,38 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"), + ], +) +lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors") +pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA. + +# Empty prompt can automatically activate LoRA capabilities. +image = pipe(prompt="", seed=0, lora_encoder_inputs=lora) +image.save("image_1.jpg") + +image = pipe(prompt="", seed=0) +image.save("image_1_origin.jpg") + +# Prompt without trigger words can also activate LoRA capabilities. +image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora) +image.save("image_2.jpg") + +image = pipe(prompt="a car", seed=0,) +image.save("image_2_origin.jpg") + +# Adjust the activation intensity through the scale parameter. +image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0) +image.save("image_3.jpg") + +image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5) +image.save("image_3_scale.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5076a7abd13775f9ca235ca5b1f4773636419c8d --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py @@ -0,0 +1,38 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + # Enable lora hotloading + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors"), + ], +) +pipe.enable_lora_merger() + +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"), +) +image = pipe(prompt="a cat", seed=0) +image.save("image_fused.jpg") diff --git a/examples/flux/model_inference/FLUX.1-dev.py b/examples/flux/model_inference/FLUX.1-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..35d1e96123f6612722b59c0bc5ab31abd08774c4 --- /dev/null +++ b/examples/flux/model_inference/FLUX.1-dev.py @@ -0,0 +1,26 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + 
+pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + +image = pipe(prompt=prompt, seed=0) +image.save("flux.jpg") + +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + seed=0, cfg_scale=2, num_inference_steps=50, +) +image.save("flux_cfg.jpg") diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py new file mode 100644 index 0000000000000000000000000000000000000000..67691659a0bc375b77f99bc0fa60c467f22ba787 --- /dev/null +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -0,0 +1,37 @@ +import importlib +import torch +from PIL import Image +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], + nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), +) + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." 
+image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("cat_crown.jpg") diff --git a/examples/flux/model_inference/Nexus-Gen-Generation.py b/examples/flux/model_inference/Nexus-Gen-Generation.py new file mode 100644 index 0000000000000000000000000000000000000000..5130d670672b6db41c71abf900e966ab862be57e --- /dev/null +++ b/examples/flux/model_inference/Nexus-Gen-Generation.py @@ -0,0 +1,32 @@ +import importlib +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), +) + +prompt = "一只可爱的猫咪" +image = pipe( + prompt=prompt, negative_prompt="", + seed=0, cfg_scale=3, num_inference_steps=50, + height=1024, width=1024, +) +image.save("cat.jpg") diff --git a/examples/flux/model_inference/Step1X-Edit.py b/examples/flux/model_inference/Step1X-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec517b6c35763e8f0a332d44e93bf5dfb282784 --- /dev/null +++ b/examples/flux/model_inference/Step1X-Edit.py @@ -0,0 +1,32 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image +import numpy as np + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), + ], +) + +image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255) +image = pipe( + prompt="draw red flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=1, rand_device='cuda' +) +image.save("image_1.jpg") + +image = pipe( + prompt="add more flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=2, rand_device='cuda' +) +image.save("image_2.jpg") diff --git a/examples/flux/model_inference_low_vram/FLEX.2-preview.py b/examples/flux/model_inference_low_vram/FLEX.2-preview.py new file mode 100644 index 0000000000000000000000000000000000000000..a4454e84fedaf6405ee2ae7db152a234edea4b98 --- /dev/null +++ 
b/examples/flux/model_inference_low_vram/FLEX.2-preview.py @@ -0,0 +1,61 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth.utils.controlnet import Annotator +import numpy as np +from PIL import Image + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 +) +image.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[200:400, 400:700] = 255 +mask = Image.fromarray(mask) +mask.save("image_mask.jpg") + +inpaint_image = image + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask, + seed=4 +) +image.save("image_2.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save("image_3.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..2994a3301c9c9d57d88a7cfcf43b2c8bdde54812 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py @@ -0,0 +1,65 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + 
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image_1 = pipe( + prompt="a beautiful Asian long-haired female college student.", + embedded_guidance=2.5, + seed=1, +) +image_1.save("image_1.jpg") + +image_2 = pipe( + prompt="transform the style to anime style.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=2, +) +image_2.save("image_2.jpg") + +image_3 = pipe( + prompt="let her smile.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=3, +) +image_3.save("image_3.jpg") + +image_4 = pipe( + prompt="let the girl play basketball.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=4, +) +image_4.save("image_4.jpg") + +image_5 = pipe( + prompt="move the girl to a park, let her sit on a chair.", + kontext_images=image_1, + embedded_guidance=2.5, + seed=5, +) +image_5.save("image_5.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..2ceb064f600c233e4061ee057bebb55184f964df --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py @@ -0,0 +1,38 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "An beautiful woman is riding a bicycle in a park, wearing a red dress" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + +image = pipe(prompt=prompt, seed=0, embedded_guidance=4.5) +image.save("flux_krea.jpg") + +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + seed=0, cfg_scale=2, num_inference_steps=50, + embedded_guidance=4.5 +) +image.save("flux_krea_cfg.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py new file mode 100644 index 0000000000000000000000000000000000000000..e0226ba79ffef81e629e884ae1d067d34647389e --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-AttriCtrl.py @@ -0,0 +1,30 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + 
device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors", **vram_config) + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +for i in [0.1, 0.3, 0.5, 0.7, 0.9]: + image = pipe(prompt="a cat on the beach", seed=2, value_controller_inputs=[i]) + image.save(f"value_control_{i}.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py new file mode 100644 index 0000000000000000000000000000000000000000..61ac25f805a2fa0f0ddcf977ad9b507e29d00de2 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py @@ -0,0 +1,48 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +import numpy as np +from PIL import Image + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image_1 = pipe( + prompt="a cat sitting on a chair", + height=1024, width=1024, + seed=8, rand_device="cuda", +) +image_1.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[100:350, 350: -300] = 255 +mask = Image.fromarray(mask) +mask.save("mask.jpg") + +image_2 = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)], + height=1024, width=1024, + seed=9, rand_device="cuda", +) +image_2.save("image_2.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..148e7ef95f08933bf1c4128ee59c7a31aa418bb0 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py @@ -0,0 +1,50 @@ +import torch +from diffsynth.pipelines.flux_image import 
FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth.utils.controlnet import Annotator +from modelscope import snapshot_download + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +snapshot_download("sd_lora/Annotators", allow_file_pattern="dpt_hybrid-midas-501f0c75.pt", local_dir="models/Annotators") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image_1 = pipe( + prompt="a beautiful Asian girl, full body, red dress, summer", + height=1024, width=1024, + seed=6, rand_device="cuda", +) +image_1.save("image_1.jpg") + +image_canny = Annotator("canny")(image_1) +image_depth = Annotator("depth")(image_1) + +image_2 = pipe( + prompt="a beautiful Asian girl, full body, red dress, winter", + controlnet_inputs=[ + ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"), + ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"), + ], + height=1024, width=1024, + seed=7, rand_device="cuda", +) +image_2.save("image_2.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..ca7c72c04379e1f7dbcb3b935f3f9255719c5b67 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py @@ -0,0 +1,44 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + 
+image_1 = pipe( + prompt="a photo of a cat, highly detailed", + height=768, width=768, + seed=0, rand_device="cuda", +) +image_1.save("image_1.jpg") + +image_1 = image_1.resize((2048, 2048)) +image_2 = pipe( + prompt="a photo of a cat, highly detailed", + controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)], + input_image=image_1, + denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=1, rand_device="cuda", +) +image_2.save("image_2.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..da0d7cacb442fa1d8c244ec56ccf7d319a9c8055 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py @@ -0,0 +1,144 @@ +import random +import torch +from PIL import Image, ImageDraw, ImageFont +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png") + masks = 
[Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="DiffSynth-Studio/Eligen", origin_file_pattern="model_bf16.safetensors"), alpha=1) + +# example 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." 
+entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) + +# example 3 +global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning," +entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"] +example(pipe, [27], 3, global_prompt, entity_prompts) + +# example 4 +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +example(pipe, [21], 4, global_prompt, entity_prompts) + +# example 5 +global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere." +entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"] +example(pipe, [0], 5, global_prompt, entity_prompts) + +# example 6 +global_prompt = "Snow White and the 6 Dwarfs." +entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"] +example(pipe, [8], 6, global_prompt, entity_prompts) + +# example 7, same prompt with different seeds +seeds = range(5, 9) +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts) diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..59f2e9f9006be9186d5181ab397fa48407b52f1a --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", 
**vram_config), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin", **vram_config), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +origin_prompt = "a rabbit in a garden, colorful flowers" +image = pipe(prompt=origin_prompt, height=1280, width=960, seed=42) +image.save("style image.jpg") + +image = pipe(prompt="A piggy", height=1280, width=960, seed=42, + ipadapter_images=[image], ipadapter_scale=0.7) +image.save("A piggy.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py new file mode 100644 index 0000000000000000000000000000000000000000..119856afd98f2c2b088debc940bbcc1d692d1d9c --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py @@ -0,0 +1,73 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from modelscope import dataset_snapshot_download +from modelscope import snapshot_download +from PIL import Image +import numpy as np + + +# This model has additional requirements. +# Please install the following packages. +# pip install facexlib insightface onnxruntime +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +snapshot_download( + "ByteDance/InfiniteYou", + allow_file_pattern="supports/insightface/models/antelopev2/*", + local_dir="models/ByteDance/InfiniteYou", +) +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin", **vram_config), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/infiniteyou/*", +) + +height, width = 1024, 1024 +controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8)) +controlnet_inputs = [ControlNetInput(image=controlnet_image, scale=1.0, processor_id="None")] + +prompt = "A man, portrait, cinematic" +id_image = "data/examples/infiniteyou/man.jpg" +id_image = Image.open(id_image).convert('RGB') +image = pipe( + prompt=prompt, seed=1, + infinityou_id_image=id_image, infinityou_guidance=1.0, + controlnet_inputs=controlnet_inputs, + num_inference_steps=50, embedded_guidance=3.5, + height=height, width=width, +) +image.save("man.jpg") + +prompt = "A woman, portrait, cinematic" +id_image = 
"data/examples/infiniteyou/woman.jpg" +id_image = Image.open(id_image).convert('RGB') +image = pipe( + prompt=prompt, seed=1, + infinityou_id_image=id_image, infinityou_guidance=1.0, + controlnet_inputs=controlnet_inputs, + num_inference_steps=50, embedded_guidance=3.5, + height=height, width=width, +) +image.save("woman.jpg") \ No newline at end of file diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5928af0edf7832edc5126618ce8e52c04fd9f6b6 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py @@ -0,0 +1,49 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors") +pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA. + +# Empty prompt can automatically activate LoRA capabilities. +image = pipe(prompt="", seed=0, lora_encoder_inputs=lora) +image.save("image_1.jpg") + +image = pipe(prompt="", seed=0) +image.save("image_1_origin.jpg") + +# Prompt without trigger words can also activate LoRA capabilities. +image = pipe(prompt="a car", seed=0, lora_encoder_inputs=lora) +image.save("image_2.jpg") + +image = pipe(prompt="a car", seed=0,) +image.save("image_2_origin.jpg") + +# Adjust the activation intensity through the scale parameter. 
+image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=1.0) +image.save("image_3.jpg") + +image = pipe(prompt="a cat", seed=0, lora_encoder_inputs=lora, lora_encoder_scale=0.5) +image.save("image_3_scale.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ce587cdac6d12923050b01c6003133274ba5f7bd --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Fusion.py @@ -0,0 +1,38 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +pipe.enable_lora_merger() + +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"), +) +pipe.load_lora( + pipe.dit, + ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"), +) +image = pipe(prompt="a cat", seed=0) +image.save("image_fused.jpg") diff --git a/examples/flux/model_inference_low_vram/FLUX.1-dev.py b/examples/flux/model_inference_low_vram/FLUX.1-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..ffaf1813fe53a19bdb21f44313ddbe388a124cb1 --- /dev/null +++ b/examples/flux/model_inference_low_vram/FLUX.1-dev.py @@ -0,0 +1,37 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "CG, masterpiece, best quality, solo, long 
hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + +image = pipe(prompt=prompt, seed=0) +image.save("flux.jpg") + +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + seed=0, cfg_scale=2, num_inference_steps=50, +) +image.save("flux_cfg.jpg") diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3050fbfa92bff5ddf6a7a9da5c5ab3da67591a --- /dev/null +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -0,0 +1,48 @@ +import importlib +import torch +from PIL import Image +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + nexus_gen_processor_config=ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." 
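+# The cat photo is passed as nexus_gen_reference_image, so the prompt only describes the edit to apply ("Add a crown."), not the whole scene.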
+image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("cat_crown.jpg") diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py new file mode 100644 index 0000000000000000000000000000000000000000..8372fcb822d0f0862a4352a77d0c329f60128b5f --- /dev/null +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py @@ -0,0 +1,43 @@ +import importlib +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config), + ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +prompt = "一只可爱的猫咪" +image = pipe( + prompt=prompt, negative_prompt="", + seed=0, cfg_scale=3, num_inference_steps=50, + height=1024, width=1024, +) +image.save("cat.jpg") diff --git a/examples/flux/model_inference_low_vram/Step1X-Edit.py b/examples/flux/model_inference_low_vram/Step1X-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3bde88ae112eced27d9ca7d7d10743c1df803f --- /dev/null +++ b/examples/flux/model_inference_low_vram/Step1X-Edit.py @@ -0,0 +1,43 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image +import numpy as np + + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors", **vram_config), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", 
origin_file_pattern="vae.safetensors", **vram_config), + ], + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255) +image = pipe( + prompt="draw red flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=1, rand_device='cuda' +) +image.save("image_1.jpg") + +image = pipe( + prompt="add more flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=2, rand_device='cuda' +) +image.save("image_2.jpg") diff --git a/examples/flux/model_training/full/FLEX.2-preview.sh b/examples/flux/model_training/full/FLEX.2-preview.sh new file mode 100644 index 0000000000000000000000000000000000000000..fffe929a902733dd0615bbb9f0e1a39548ebebb7 --- /dev/null +++ b/examples/flux/model_training/full/FLEX.2-preview.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 200 \ + --model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLEX.2-preview_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh b/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..4938f10d4467025aff5ec873f44d6da6df06b811 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \ + --data_file_keys "image,kontext_images" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-Kontext-dev_full" \ + --trainable_models "dit" \ + --extra_inputs "kontext_images" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-Krea-dev.sh b/examples/flux/model_training/full/FLUX.1-Krea-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..053b0fa97adb031a411faa0bf0ad05eaabad88ba --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-Krea-dev.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-Krea-dev_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh b/examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba620fd618d2be4d9d4acc57c3b1ae2554081860 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-AttriCtrl.sh @@ -0,0 +1,14 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_attrictrl.csv \ + --data_file_keys "image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.value_controller.encoders.0." 
\ + --output_path "./models/train/FLUX.1-dev-AttriCtrl_full" \ + --trainable_models "value_controller" \ + --extra_inputs "value_controller_inputs" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh new file mode 100644 index 0000000000000000000000000000000000000000..d362313a89faf9ccca4b1b1302b2cf7e0e8c198a --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \ + --data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.controlnet.models.0." \ + --output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image,controlnet_inpaint_mask" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0a56af3d507707feb745714a5ce35503de6d297 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.controlnet.models.0." 
\ + --output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image,controlnet_processor_id" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh new file mode 100644 index 0000000000000000000000000000000000000000..85a0228fa2582b0e42a5b9132e95b5571bc8b0bf --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.controlnet.models.0." \ + --output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_full" \ + --trainable_models "controlnet" \ + --extra_inputs "controlnet_image" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh new file mode 100644 index 0000000000000000000000000000000000000000..6db5e793f10fe1544cd61356c6dabb14a0da4059 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh @@ -0,0 +1,14 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \ + --data_file_keys "image,ipadapter_images" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.ipadapter." 
\ + --output_path "./models/train/FLUX.1-dev-IP-Adapter_full" \ + --trainable_models "ipadapter" \ + --extra_inputs "ipadapter_images" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh b/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh new file mode 100644 index 0000000000000000000000000000000000000000..789879506f0272532dd2e7a943fec0b435fcaadd --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \ + --data_file_keys "image,controlnet_image,infinityou_id_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe." \ + --output_path "./models/train/FLUX.1-dev-InfiniteYou_full" \ + --trainable_models "controlnet,image_proj_model" \ + --extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh b/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh new file mode 100644 index 0000000000000000000000000000000000000000..b788434f69e45fbcbcc1fab240b48fa8c387874a --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh @@ -0,0 +1,14 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_lora_encoder.csv \ + --data_file_keys "image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev:model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.lora_encoder." 
\ + --output_path "./models/train/FLUX.1-dev-LoRA-Encoder_full" \ + --trainable_models "lora_encoder" \ + --extra_inputs "lora_encoder_inputs" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev.sh b/examples/flux/model_training/full/FLUX.1-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..a9f582009ed542398b630affe7fa9bb64621251f --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/Nexus-Gen.sh b/examples/flux/model_training/full/Nexus-Gen.sh new file mode 100644 index 0000000000000000000000000000000000000000..6f2960d2a5776fe3696a34e1504278240bfc155a --- /dev/null +++ b/examples/flux/model_training/full/Nexus-Gen.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 262144 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_full" \ + --trainable_models "dit" \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/flux/model_training/full/Step1X-Edit.sh b/examples/flux/model_training/full/Step1X-Edit.sh new file mode 100644 index 0000000000000000000000000000000000000000..03ddfdadb0d7f8a3c461e5526ca92eef8b8aa3fd --- /dev/null +++ b/examples/flux/model_training/full/Step1X-Edit.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \ + --data_file_keys "image,step1x_reference_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Step1X-Edit_full" \ + --trainable_models "dit" \ + --extra_inputs "step1x_reference_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/flux/model_training/full/accelerate_config.yaml b/examples/flux/model_training/full/accelerate_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83280f73f315a32eccb065f351d66b4b2678759d --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux/model_training/full/accelerate_config_zero2offload.yaml b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a75f3d91eeae160409650b482e5383ac26b297b --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: 'cpu' + offload_param_device: 'cpu' + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux/model_training/lora/FLEX.2-preview.sh b/examples/flux/model_training/lora/FLEX.2-preview.sh new file mode 100644 index 0000000000000000000000000000000000000000..444e91c1bd56cdba1d1a8a9569aa0b371ab53a5a --- /dev/null +++ b/examples/flux/model_training/lora/FLEX.2-preview.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLEX.2-preview_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh b/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..f45707ef1908c35d36c583baa079973dd036ed44 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \ + --data_file_keys "image,kontext_images" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-Kontext-dev_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "kontext_images" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh b/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..cea0009aeac2b8b55607a3b16126a46514a3fae0 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-Krea-dev:flux1-krea-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-Krea-dev_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh b/examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh new file mode 100644 index 0000000000000000000000000000000000000000..8e6d8c9018f64f9fb44a9bc69905dfd429866b43 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-AttriCtrl.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_attrictrl.csv \ + --data_file_keys "image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,DiffSynth-Studio/AttriCtrl-FLUX.1-Dev:models/brightness.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-AttriCtrl_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "value_controller_inputs" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh new file mode 100644 index 0000000000000000000000000000000000000000..1d7afc6c7f14a6fd7b8ab7f1fc45bae10f9fed77 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \ + --data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image,controlnet_inpaint_mask" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1272c2d17a33eff8124ae8148c842284759d08b --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image,controlnet_processor_id" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh new file mode 100644 index 0000000000000000000000000000000000000000..398e270ff874d17079afc044cbe77f22db352353 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \ + --data_file_keys "image,controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh new file mode 100644 index 0000000000000000000000000000000000000000..0579cd2063b6fc8e9ff844cdca2043639590889f --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-EliGen_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh b/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh new file mode 100644 index 0000000000000000000000000000000000000000..e11007530a4835cb8b670c07e358d9c22f2295bc --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \ + --data_file_keys "image,ipadapter_images" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-dev-IP-Adapter_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "ipadapter_images" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh b/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh new file mode 100644 index 0000000000000000000000000000000000000000..493830199d9fd8961e2e5bd1bbd7bb8083775edc --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \ + --data_file_keys "image,controlnet_image,infinityou_id_image" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-InfiniteYou_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev.sh b/examples/flux/model_training/lora/FLUX.1-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..5118857d7452b14ac17449d0d2daef8f2d348c8a --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-dev_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/Nexus-Gen.sh b/examples/flux/model_training/lora/Nexus-Gen.sh new file mode 100644 index 0000000000000000000000000000000000000000..b98bd581446dd22f59ede2437a28b61025206fcb --- /dev/null +++ b/examples/flux/model_training/lora/Nexus-Gen.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/Step1X-Edit.sh b/examples/flux/model_training/lora/Step1X-Edit.sh new file mode 100644 index 0000000000000000000000000000000000000000..a7f1d8fa12a353f5d0c070ef48dd1789154bcd35 --- /dev/null +++ b/examples/flux/model_training/lora/Step1X-Edit.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \ + --data_file_keys "image,step1x_reference_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Step1X-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --extra_inputs "step1x_reference_image" \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4903b976c356070abb066b71a3be6b8cc5d1308b --- /dev/null +++ b/examples/flux/model_training/train.py @@ -0,0 +1,193 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class FluxTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_1_path=None, tokenizer_2_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_1_config = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/") if tokenizer_1_path is None else ModelConfig(tokenizer_1_path) + tokenizer_2_config = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/") if tokenizer_2_path is None else ModelConfig(tokenizer_2_path) + self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_1_config=tokenizer_1_config, tokenizer_2_config=tokenizer_2_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this 
pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "embedded_guidance": 1, + "t5_sequence_length": 512, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def flux_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--tokenizer_1_path", type=str, default=None, help="Path to CLIP tokenizer.") + parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to T5 tokenizer.") + parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.") + return parser + + +def convert_lora_format(state_dict, alpha=None): + prefix_rename_dict = { + "single_blocks": "lora_unet_single_blocks", + "blocks": "lora_unet_double_blocks", + } + middle_rename_dict = { + "norm.linear": "modulation_lin", + "to_qkv_mlp": "linear1", + "proj_out": "linear2", + "norm1_a.linear": "img_mod_lin", + "norm1_b.linear": "txt_mod_lin", + "attn.a_to_qkv": "img_attn_qkv", + "attn.b_to_qkv": "txt_attn_qkv", + "attn.a_to_out": "img_attn_proj", + "attn.b_to_out": "txt_attn_proj", + "ff_a.0": "img_mlp_0", + "ff_a.2": "img_mlp_2", + "ff_b.0": "txt_mlp_0", + "ff_b.2": "txt_mlp_2", + } + suffix_rename_dict = { + "lora_B.weight": "lora_up.weight", + "lora_A.weight": "lora_down.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + names = name.split(".") + if names[-2] != "lora_A" and names[-2] != "lora_B": + names.pop(-2) + prefix = names[0] + middle = ".".join(names[2:-2]) + suffix = ".".join(names[-2:]) + block_id = names[1] + if middle not in middle_rename_dict: + continue + rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." 
+ suffix_rename_dict[suffix] + state_dict_[rename] = param + if rename.endswith("lora_up.weight"): + lora_alpha = alpha if alpha is not None else param.shape[-1] + state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0] + return state_dict_ + + +if __name__ == "__main__": + parser = flux_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = FluxTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_1_path=args.tokenizer_1_path, + tokenizer_2_path=args.tokenizer_2_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device=accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + state_dict_converter=convert_lora_format if args.align_to_opensource_format else lambda x:x, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/flux/model_training/validate_full/FLEX.2-preview.py b/examples/flux/model_training/validate_full/FLEX.2-preview.py new file mode 100644 index 0000000000000000000000000000000000000000..6e44fd6e568f0378c57949c37d118f90ffa46ab0 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLEX.2-preview.py @@ -0,0 +1,20 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLEX.2-preview_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe(prompt="dog,white and brown dog, sitting on wall, under 
pink flowers", seed=0) +image.save("image_FLEX.2-preview_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py b/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3604780b04aa3ee4c2b66ecb1eeca597c07d66 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py @@ -0,0 +1,26 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-Kontext-dev_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe( + prompt="Make the dog turn its head around.", + kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-Kontext-dev_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py b/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..044055e271d1e4d6c59f1af17b4067bb5e5ea8b5 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py @@ -0,0 +1,20 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-Krea-dev_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-Krea-dev_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-AttriCtrl.py b/examples/flux/model_training/validate_full/FLUX.1-dev-AttriCtrl.py new file mode 100644 index 0000000000000000000000000000000000000000..74a0dcaf30bd7965670786f392d084df4acab91d --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-AttriCtrl.py @@ -0,0 +1,21 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + 
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-AttriCtrl_full/epoch-0.safetensors") +pipe.value_controller.encoders[0].load_state_dict(state_dict) + +image = pipe(prompt="a cat", seed=0, value_controller_inputs=0.1, rand_device="cuda") +image.save("image_FLUX.1-dev-AttriCtrl_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py new file mode 100644 index 0000000000000000000000000000000000000000..e27d142935e13813a071810e2eb62337f1bf899e --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py @@ -0,0 +1,31 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full/epoch-0.safetensors") +pipe.controlnet.models[0].load_state_dict(state_dict) + +image = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"), + inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"), + scale=0.9 + )], + height=1024, width=1024, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..1db06cbca46c6d2878cc1a7e4dcca6cd01bc555e --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py @@ -0,0 +1,31 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + 
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Union-alpha_full/epoch-0.safetensors") +pipe.controlnet.models[0].load_state_dict(state_dict) + +image = pipe( + prompt="a dog", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/canny/image_1.jpg"), + scale=0.9, + processor_id="canny", + )], + height=768, width=768, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Union-alpha_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..330e2614b8d693fa97396fad25bbaa739d150889 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py @@ -0,0 +1,30 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Upscaler_full/epoch-0.safetensors") +pipe.controlnet.models[0].load_state_dict(state_dict) + +image = pipe( + prompt="a dog", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/upscale/image_1.jpg"), + scale=0.9 + )], + height=768, width=768, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Upscaler_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7c15dedbe926795504c5a5b5178877ec70bbd712 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,28 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), + ], +) +state_dict = 
load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors") +pipe.ipadapter.load_state_dict(state_dict) + +image = pipe( + prompt="a dog", + ipadapter_images=Image.open("data/example_image_dataset/1.jpg"), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-dev-IP-Adapter_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py new file mode 100644 index 0000000000000000000000000000000000000000..311c5b939cfc922b59620be9d19b7e15ad087c8c --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py @@ -0,0 +1,33 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-InfiniteYou_full/epoch-0.safetensors") +state_dict_projector = {i.replace("image_proj_model.", ""): state_dict[i] for i in state_dict if i.startswith("image_proj_model.")} +pipe.image_proj_model.load_state_dict(state_dict_projector) +state_dict_controlnet = {i.replace("controlnet.models.0.", ""): state_dict[i] for i in state_dict if i.startswith("controlnet.models.0.")} +pipe.controlnet.models[0].load_state_dict(state_dict_controlnet) + +image = pipe( + prompt="a man with a red hat", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"), + )], + height=1024, width=1024, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-InfiniteYou_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5c1d206ac341fcc256d7ebd081f81278321ba280 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"), + ], +) 
+state_dict = load_state_dict("models/train/FLUX.1-dev-LoRA-Encoder_full/epoch-0.safetensors") +pipe.lora_encoder.load_state_dict(state_dict) + +lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors") +pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA. + +image = pipe(prompt="", seed=0, lora_encoder_inputs=lora) +image.save("image_FLUX.1-dev-LoRA-Encoder_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev.py b/examples/flux/model_training/validate_full/FLUX.1-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f9f7e4a7e87efc79badebcccb6433ec734887f --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev.py @@ -0,0 +1,20 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-dev_full.jpg") diff --git a/examples/flux/model_training/validate_full/Nexus-Gen.py b/examples/flux/model_training/validate_full/Nexus-Gen.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2c2bc2bd908bf6845cfbd507291c330a39a3d2 --- /dev/null +++ b/examples/flux/model_training/validate_full/Nexus-Gen.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-NexusGen-Edit_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." 
+image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_full.jpg") diff --git a/examples/flux/model_training/validate_full/Step1X-Edit.py b/examples/flux/model_training/validate_full/Step1X-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..feaac7aabd5fd6f8b1fe211b0d5031561ebc9ffc --- /dev/null +++ b/examples/flux/model_training/validate_full/Step1X-Edit.py @@ -0,0 +1,25 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/Step1X-Edit_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe( + prompt="Make the dog turn its head around.", + step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)), + height=768, width=768, cfg_scale=6, + seed=0 +) +image.save("image_Step1X-Edit_full.jpg") diff --git a/examples/flux/model_training/validate_lora/FLEX.2-preview.py b/examples/flux/model_training/validate_lora/FLEX.2-preview.py new file mode 100644 index 0000000000000000000000000000000000000000..a9059181d3a18f351042dc16a1052d6f4a0aac99 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLEX.2-preview.py @@ -0,0 +1,18 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLEX.2-preview_lora/epoch-4.safetensors", alpha=1) + +image = pipe(prompt="dog,white and brown dog, sitting on wall, under pink flowers", seed=0) +image.save("image_FLEX.2-preview_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py b/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e681eb551b9d83b4d49e3ab47b4cc5346c6978 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="Make the dog turn its head around.", + kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-Kontext-dev_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py b/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..7df61cc3eccca1972c3900f7e1c29a87526f1253 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py @@ -0,0 +1,18 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Krea-dev", origin_file_pattern="flux1-krea-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-Krea-dev_lora/epoch-4.safetensors", alpha=1) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-Krea-dev_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-AttriCtrl.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-AttriCtrl.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb81d25fd8244b84c514c868b04862f604e9a07 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-AttriCtrl.py @@ -0,0 +1,19 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-AttriCtrl_lora/epoch-3.safetensors", alpha=1) + +image = pipe(prompt="a cat", seed=0, value_controller_inputs=0.1, rand_device="cuda") +image.save("image_FLUX.1-dev-AttriCtrl_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py new file mode 100644 index 0000000000000000000000000000000000000000..cbedf7c1ea5b449fdef970abc4bcda8c6c5f855f --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py @@ -0,0 +1,29 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from PIL import Image + + +pipe = 
FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"), + inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"), + scale=0.9 + )], + height=1024, width=1024, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..c64c40e3ebc37f69b8de8c4f0a351bc6f9474280 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py @@ -0,0 +1,29 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Union-alpha_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="a dog", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/canny/image_1.jpg"), + scale=0.9, + processor_id="canny", + )], + height=768, width=768, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Union-alpha_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..935c6fc9bdb39c838dd184ef3d29eb75dbb60998 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py @@ -0,0 +1,28 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + 
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Upscaler_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="a dog", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/upscale/image_1.jpg"), + scale=0.9 + )], + height=768, width=768, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-Controlnet-Upscaler_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..b252269c8e1fc8a18a3b5dfde58eeb69dee3cc84 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-EliGen_lora/epoch-4.safetensors", alpha=1) + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] +# generate image +image = pipe( + prompt=global_prompt, + cfg_scale=1.0, + num_inference_steps=50, + embedded_guidance=3.5, + seed=42, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, +) +image.save(f"EliGen_lora.png") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..31c295bc9fde56f20566c1f81dbcaecc2ede97f2 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,26 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + 
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="dog,white and brown dog, sitting on wall, under pink flowers", + ipadapter_images=Image.open("data/example_image_dataset/1.jpg"), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-dev-IP-Adapter_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py new file mode 100644 index 0000000000000000000000000000000000000000..9a76170b86ac2201e997b17b326231d70b65877d --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py @@ -0,0 +1,28 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig, ControlNetInput +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"), + ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-InfiniteYou_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="a man with a red hat", + controlnet_inputs=[ControlNetInput( + image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"), + )], + height=1024, width=1024, + seed=0, rand_device="cuda", +) +image.save("image_FLUX.1-dev-InfiniteYou_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev.py b/examples/flux/model_training/validate_lora/FLUX.1-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..110a65cdb82e6bca2b21224d55e2cabd50d244a3 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev.py @@ -0,0 +1,18 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev_lora/epoch-4.safetensors", alpha=1) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-dev_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/Nexus-Gen.py b/examples/flux/model_training/validate_lora/Nexus-Gen.py new file mode 100644 index 
0000000000000000000000000000000000000000..447ed8f09f5857fcc3f4e4c16cfd66db4f84c768 --- /dev/null +++ b/examples/flux/model_training/validate_lora/Nexus-Gen.py @@ -0,0 +1,26 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-NexusGen-Edit_lora/epoch-4.safetensors", alpha=1) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=1.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_lora.jpg") diff --git a/examples/flux/model_training/validate_lora/Step1X-Edit.py b/examples/flux/model_training/validate_lora/Step1X-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..e89ff9868e2795d9af79b4935f934b7c0c2e5416 --- /dev/null +++ b/examples/flux/model_training/validate_lora/Step1X-Edit.py @@ -0,0 +1,23 @@ +import torch +from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Step1X-Edit_lora/epoch-4.safetensors", alpha=1) + +image = pipe( + prompt="Make the dog turn its head around.", + step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)), + height=768, width=768, cfg_scale=6, + seed=0 +) +image.save("image_Step1X-Edit_lora.jpg") diff --git a/examples/flux2/model_inference/FLUX.2-dev.py b/examples/flux2/model_inference/FLUX.2-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5176cc97193bf3a0b99c988a4b07a37bf715f4 --- /dev/null +++ b/examples/flux2/model_inference/FLUX.2-dev.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", 
**vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), +) +prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image_FLUX.2-dev.jpg") diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-dev.py b/examples/flux2/model_inference_low_vram/FLUX.2-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..41d588e873e3a434515d6a551aff15223efadb66 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-dev.py @@ -0,0 +1,27 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image.jpg") \ No newline at end of file diff --git a/examples/flux2/model_training/lora/FLUX.2-dev.sh b/examples/flux2/model_training/lora/FLUX.2-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..4b1e74b4d6f23e6af616a09f0c2980558c2a973c --- /dev/null +++ b/examples/flux2/model_training/lora/FLUX.2-dev.sh @@ -0,0 +1,32 @@ +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --task "sft:data_process" + +accelerate launch examples/flux2/model_training/train.py \ + --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.2-dev-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --task "sft:train" diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..30408a12ba34dfade13949af1a8665f0c7356274 --- /dev/null +++ b/examples/flux2/model_training/train.py @@ -0,0 +1,143 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class Flux2ImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = 
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. 
+ "embedded_guidance": 1.0, + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def flux2_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + return parser + + +if __name__ == "__main__": + parser = flux2_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = Flux2ImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device=accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-dev.py b/examples/flux2/model_training/validate_lora/FLUX.2-dev.py new file mode 100644 index 0000000000000000000000000000000000000000..e67e2a7bfeb9b7def9f1f61f84135ed49f9a6fff --- /dev/null +++ b/examples/flux2/model_training/validate_lora/FLUX.2-dev.py @@ -0,0 +1,28 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + 
"preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "./models/train/FLUX.2-dev-LoRA-splited/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt, seed=0) +image.save("image_FLUX.2-dev_lora.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py new file mode 100644 index 0000000000000000000000000000000000000000..85b9b96886a3a48bdf30319a99557a952721ab81 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py new file mode 100644 index 0000000000000000000000000000000000000000..6676868ad0e24d09d962211dda758f29cf7180f2 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py @@ -0,0 +1,32 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) + +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328)) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb98e0a348ba5b9e9748496c05bc8fc8aef154e --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +prompt = "a cat with sunglasses" +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328)) +image = pipe( + prompt, seed=0, + input_image=controlnet_image, inpaint_mask=inpaint_mask, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py b/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py new file mode 100644 index 0000000000000000000000000000000000000000..007538f76b5c63c3b72ab7d2cd592773677aabb5 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py @@ -0,0 +1,25 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +from modelscope import snapshot_download +import torch, math + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + 
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +snapshot_download("MusePublic/Qwen-Image-Distill", allow_file_pattern="qwen_image_distill_3step.safetensors", cache_dir="models") +lora_state_dict = load_state_dict("models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors") +lora_state_dict = {i.replace("base_model.model.", ""): j for i, j in lora_state_dict.items()} +pipe.load_lora(pipe.dit, state_dict=lora_state_dict) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5)) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000000000000000000000000000000000000..c13a417f5aba55820b8111b872e8cf7be26e6ed1 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py b/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..aad1fdd7eb222e2d87fb57180c2ffb5d1db4c3ec --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +import torch + +snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA") +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors") + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") \ No newline at end of file diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py new file mode 100644 index 
0000000000000000000000000000000000000000..5dcffbbf3112f84e7dcb6ecc2e73f4fb3ecb4af4 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from PIL import Image +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024) +image_1.save("image1.jpg") + +image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024) +image_2.save("image2.jpg") + +prompt = "生成这两个人的合影" +edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")] +image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True) +image_3.save("image3.jpg") + +# Qwen-Image-Edit-2509 is a multi-image editing model. +# Please use a list to input `edit_image`, even if the input contains only one image. +# edit_image = [Image.open("image.jpg")] +# Please do not input the image directly. +# edit_image = Image.open("image.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py new file mode 100644 index 0000000000000000000000000000000000000000..c18eaa8fd329fda7f4784fe4f3ce22309bbda2ef --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py @@ -0,0 +1,25 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from modelscope import snapshot_download + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors") + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768) +image.save("image.jpg") + +prompt = "将裙子变成粉色" +image = image.resize((512, 384)) +image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False) +image.save(f"image2.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit.py b/examples/qwen_image/model_inference/Qwen-Image-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..8a47756ffb71e778d204303bbb813c0fa97d987e --- /dev/null 
+++ b/examples/qwen_image/model_inference/Qwen-Image-Edit.py @@ -0,0 +1,25 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +input_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024) +input_image.save("image1.jpg") + +prompt = "将裙子改为粉色" +# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio +image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True) +image.save(f"image2.jpg") + +# edit_image_auto_resize=False: do not resize input image +image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False) +image.save(f"image3.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py b/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf0a12afd80cbde3d711f0942d593afc38591a0 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py @@ -0,0 +1,114 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +import random + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 
255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + + +def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280): + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png" + ) + masks = [ + Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height)) + for i in range(len(entity_prompts)) + ] + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=40, + seed=seed, + height=height, + width=width, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_poster_example_{example_id}_{seed}.png") + image = Image.new("RGB", (width, height), (0, 0, 0)) + visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png") + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +snapshot_download( + "DiffSynth-Studio/Qwen-Image-EliGen-Poster", + local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster", + allow_file_pattern="model.safetensors", +) +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors") +global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。" +entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"] +seed = [42] +example(pipe, seed, 1, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py b/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py new file mode 100644 index 0000000000000000000000000000000000000000..82bab2d0236c058216f23406be1220db3096cdfc --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py @@ -0,0 +1,106 @@ +import torch +import random +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + 
(204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))] + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=40, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors") + +seeds = [0] + +global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background." 
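+# Each entity prompt below is paired by index with a mask image
+# (data/examples/eligen/qwen-image/example_7/{i}.png, loaded inside example()),
+# so a mask file must exist for every entity prompt.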
+entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts) + +global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。" +entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"] +example(pipe, seeds, 4, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference/Qwen-Image-EliGen.py b/examples/qwen_image/model_inference/Qwen-Image-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..d49112317f3278d56df045cec7f26155773e869f --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-EliGen.py @@ -0,0 +1,107 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +import random + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + negative_prompt = "" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=30, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + 
image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors") + +# example 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." 
+entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py b/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py new file mode 100644 index 0000000000000000000000000000000000000000..516ee08258c66ce86f40a13034f47d81ccfa62e1 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py @@ -0,0 +1,35 @@ +from PIL import Image +import torch +from modelscope import dataset_snapshot_download, snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.utils.controlnet import Annotator + +allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"] +snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern) + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors") + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg") +origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024)) +annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal'] +for annotator_id in annotator_ids: + annotator = Annotator(processor_id=annotator_id, device="cuda") + control_image = annotator(origin_image) + control_image.save(f"{annotator.processor_id}.png") + + control_prompt = "Context_Control. 
" + prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。" + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024) + image.save(f"image_{annotator.processor_id}.png") diff --git a/examples/qwen_image/model_inference/Qwen-Image-i2L.py b/examples/qwen_image/model_inference/Qwen-Image-i2L.py new file mode 100644 index 0000000000000000000000000000000000000000..87061d829ff35742d4ad21d9e870c48f842c8077 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-i2L.py @@ -0,0 +1,110 @@ +from diffsynth.pipelines.qwen_image import ( + QwenImagePipeline, ModelConfig, + QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode +) +from diffsynth.utils.lora import merge_lora +from diffsynth import load_state_dict +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + + +def demo_style(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/1/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/style/1/0.jpg"), + Image.open("data/examples/assets/style/1/1.jpg"), + Image.open("data/examples/assets/style/1/2.jpg"), + Image.open("data/examples/assets/style/1/3.jpg"), + Image.open("data/examples/assets/style/1/4.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + save_file(lora, "model_style.safetensors") + + +def demo_coarse_fine_bias(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/lora/3/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/lora/3/0.jpg"), + Image.open("data/examples/assets/lora/3/1.jpg"), + Image.open("data/examples/assets/lora/3/2.jpg"), + Image.open("data/examples/assets/lora/3/3.jpg"), + 
Image.open("data/examples/assets/lora/3/4.jpg"), + Image.open("data/examples/assets/lora/3/5.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors") + lora_bias.download_if_necessary() + lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda") + lora = merge_lora([lora, lora_bias]) + save_file(lora, "model_coarse_fine_bias.safetensors") + + +def generate_image(lora_path, prompt, seed): + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + ) + pipe.load_lora(pipe.dit, lora_path) + image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50) + return image + + +demo_style() +image = generate_image("model_style.safetensors", "a cat", 0) +image.save("image_1.jpg") + +demo_coarse_fine_bias() +image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1) +image.save("image_2.jpg") diff --git a/examples/qwen_image/model_inference/Qwen-Image.py b/examples/qwen_image/model_inference/Qwen-Image.py new file mode 100644 index 0000000000000000000000000000000000000000..275cfba8c33db6971b87493f2966b6b235744588 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2deb2f1cd05e13073f7942b195b268a36d9a74 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py @@ -0,0 +1,42 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", 
+} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbb0818e1d7c89a324a10236fe6037cd84bb146 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py @@ -0,0 +1,43 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) + +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328)) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py 
b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..ffee149e110bdc0959044209e6737567298697fd --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +prompt = "a cat with sunglasses" +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328)) +image = pipe( + prompt, seed=0, + input_image=controlnet_image, inpaint_mask=inpaint_mask, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9566742cbdffea694613d353ef2c8f30471a41 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py @@ -0,0 +1,36 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.core import load_state_dict +from modelscope import snapshot_download +import torch, math + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, # bfloat16 is recommended. + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, # bfloat16 is recommended. 
+ "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +snapshot_download("MusePublic/Qwen-Image-Distill", allow_file_pattern="qwen_image_distill_3step.safetensors", cache_dir="models") +lora_state_dict = load_state_dict("models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors", device="cuda", torch_dtype=torch.bfloat16) +lora_state_dict = {i.replace("base_model.model.", "").replace(".weight", ".default.weight"): j for i, j in lora_state_dict.items()} +pipe.load_lora(pipe.dit, state_dict=lora_state_dict, hotload=True) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5)) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000000000000000000000000000000000000..8f99e1c2e2c96e30e566622bf70b19092b075513 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py @@ -0,0 +1,28 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8f2323131b8c9a658986d7d7e2e6e6e4f3fa1e --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import snapshot_download +import torch + 
+snapshot_download("DiffSynth-Studio/Qwen-Image-Distill-LoRA", local_dir="models/DiffSynth-Studio/Qwen-Image-Distill-LoRA") +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Distill-LoRA/model.safetensors", hotload=True) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") \ No newline at end of file diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py new file mode 100644 index 0000000000000000000000000000000000000000..97357aae623e71a806d992dd82e8adc6817f963e --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py @@ -0,0 +1,43 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image_1 = pipe(prompt="一位少女", seed=0, num_inference_steps=40, height=1328, width=1024) +image_1.save("image1.jpg") + +image_2 = pipe(prompt="一位老人", seed=0, num_inference_steps=40, height=1328, width=1024) +image_2.save("image2.jpg") + +prompt = "生成这两个人的合影" +edit_image = [Image.open("image1.jpg"), Image.open("image2.jpg")] +image_3 = pipe(prompt, edit_image=edit_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True) +image_3.save("image3.jpg") + +# Qwen-Image-Edit-2509 is a multi-image editing model. +# Please use a list to input `edit_image`, even if the input contains only one image. +# edit_image = [Image.open("image.jpg")] +# Please do not input the image directly. 
+# edit_image = Image.open("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py new file mode 100644 index 0000000000000000000000000000000000000000..0a849ca8bb6a69cf7aec3b7eb92253d82c6c7cf1 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py @@ -0,0 +1,37 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from modelscope import snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +snapshot_download("DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", local_dir="models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix/model.safetensors", hotload=True) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=768) +image.save("image.jpg") + +prompt = "将裙子变成粉色" +image = image.resize((512, 384)) +image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False) +image.save(f"image2.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..51193663106ae82b42d71d3ad7c301b39b8aec81 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py @@ -0,0 +1,37 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / 
(1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +input_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024) +input_image.save("image1.jpg") + +prompt = "将裙子改为粉色" +# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio +image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True) +image.save(f"image2.jpg") + +# edit_image_auto_resize=False: do not resize input image +image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False) +image.save(f"image3.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py new file mode 100644 index 0000000000000000000000000000000000000000..276aaa204d89acdb3554bfc7b91fa6418f751380 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py @@ -0,0 +1,125 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +import random + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + + +def example(pipe, seeds, example_id, global_prompt, entity_prompts, height=784, width=1280): + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/eligen/poster/example_{example_id}/*.png" + ) + masks = [ + 
Image.open(f"./data/examples/eligen/poster/example_{example_id}/{i}.png").convert('RGB').resize((width, height)) + for i in range(len(entity_prompts)) + ] + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=40, + seed=seed, + height=height, + width=width, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_poster_example_{example_id}_{seed}.png") + image = Image.new("RGB", (width, height), (0, 0, 0)) + visualize_masks(image, masks, entity_prompts, f"eligen_poster_example_{example_id}_mask_{seed}.png") + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +snapshot_download( + "DiffSynth-Studio/Qwen-Image-EliGen-Poster", + local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-Poster", + allow_file_pattern="model.safetensors", +) +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-Poster/model.safetensors", hotload=True) +global_prompt = "一张以柔粉紫为背景的海报,左侧有大号粉紫色文字“Qwen-Image EliGen-Poster”,粉紫色椭圆框内白色小字:“图像精确分区控制模型”。右侧有一只小兔子在拆礼物,旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)。背景有一些白云点缀。整体风格卡通可爱,传达节日惊喜的主题。" +entity_prompts = ["粉紫色文字“Qwen-Image EliGen-Poster”", "粉紫色椭圆框内白色小字:“图像精确分区控制模型”", "一只小兔子在拆礼物,小兔子旁边站着一只头顶迷你烟花发射器的小龙(卡通Q版)"] +seed = [42] +example(pipe, seed, 1, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py new file mode 100644 index 0000000000000000000000000000000000000000..353ff19edc96b701aebcd4610e7edfe4025a90f3 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py @@ -0,0 +1,117 @@ +import torch +import random +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + 
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB').resize((1024, 1024)) for i in range(len(entity_prompts))] + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=40, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen-V2", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen-V2", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors", hotload=True) + +seeds = [0] + +global_prompt = "写实摄影风格. A beautiful asia woman wearing white dress, she is holding a mirror with her right arm, with a beach background." 
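# Annotation (not part of the original script): entity_prompts[i] is paired by index with the mask
# ./data/examples/eligen/qwen-image/example_{example_id}/{i}.png loaded in example() above, so the
# number of prompts must match the number of mask files; as in visualize_masks(), white pixels in
# mask i mark the region that entity_prompts[i] is meant to control.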
+entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts) + +global_prompt = "写实摄影风格, 细节丰富。街头一位漂亮的女孩,穿着衬衫和短裤,手持写有“实体控制”的标牌,背景是繁忙的城市街道,阳光明媚,行人匆匆。" +entity_prompts = ["一个漂亮的女孩", "标牌 '实体控制'", "短裤", "衬衫"] +example(pipe, seeds, 4, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..e2cbc2e17e3d9b27950f2169825b2873c6f8b1d4 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py @@ -0,0 +1,118 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download, snapshot_download +import random + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + negative_prompt = "" + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=4.0, + negative_prompt=negative_prompt, + num_inference_steps=30, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + 
image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors", hotload=True) + +# example 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." 
+entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "yellow belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f5941899e24430ddede016eb8b8d9d0307d36e --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py @@ -0,0 +1,46 @@ +from PIL import Image +import torch +from modelscope import dataset_snapshot_download, snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.utils.controlnet import Annotator + +allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"] +snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern) + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +snapshot_download("DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", local_dir="models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union", allow_file_pattern="model.safetensors") +pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors", hotload=True) + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/qwen-image-context-control/image.jpg") +origin_image = Image.open("data/examples/qwen-image-context-control/image.jpg").resize((1024, 1024)) +annotator_ids = ['openpose', 'canny', 'depth', 'lineart', 'softedge', 'normal'] +for annotator_id in annotator_ids: + annotator = Annotator(processor_id=annotator_id, device="cuda") + control_image = annotator(origin_image) + control_image.save(f"{annotator.processor_id}.png") + + control_prompt = "Context_Control. 
" + prompt = f"{control_prompt}一个穿着淡蓝色的漂亮女孩正在翩翩起舞,背景是梦幻的星空,光影交错,细节精致。" + negative_prompt = "网格化,规则的网格,模糊, 低分辨率, 低质量, 变形, 畸形, 错误的解剖学, 变形的手, 变形的身体, 变形的脸, 变形的头发, 变形的眼睛, 变形的嘴巴" + image = pipe(prompt, seed=1, negative_prompt=negative_prompt, context_image=control_image, height=1024, width=1024) + image.save(f"image_{annotator.processor_id}.png") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py new file mode 100644 index 0000000000000000000000000000000000000000..b91d606b0125a35f90752b5a84b3ba0af4a5daf2 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py @@ -0,0 +1,134 @@ +from diffsynth.pipelines.qwen_image import ( + QwenImagePipeline, ModelConfig, + QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode +) +from diffsynth.utils.lora import merge_lora +from diffsynth import load_state_dict +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +vram_config_disk_offload = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +def demo_style(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Style.safetensors", **vram_config_disk_offload), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/1/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/style/1/0.jpg"), + Image.open("data/examples/assets/style/1/1.jpg"), + Image.open("data/examples/assets/style/1/2.jpg"), + Image.open("data/examples/assets/style/1/3.jpg"), + Image.open("data/examples/assets/style/1/4.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + save_file(lora, "model_style.safetensors") + + +def demo_coarse_fine_bias(): + # Load models + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", 
origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Coarse.safetensors", **vram_config_disk_offload), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Fine.safetensors", **vram_config_disk_offload), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + + # Load images + snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/lora/3/*", + local_dir="data/examples" + ) + images = [ + Image.open("data/examples/assets/lora/3/0.jpg"), + Image.open("data/examples/assets/lora/3/1.jpg"), + Image.open("data/examples/assets/lora/3/2.jpg"), + Image.open("data/examples/assets/lora/3/3.jpg"), + Image.open("data/examples/assets/lora/3/4.jpg"), + Image.open("data/examples/assets/lora/3/5.jpg"), + ] + + # Model inference + with torch.no_grad(): + embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + lora_bias = ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-i2L", origin_file_pattern="Qwen-Image-i2L-Bias.safetensors") + lora_bias.download_if_necessary() + lora_bias = load_state_dict(lora_bias.path, torch_dtype=torch.bfloat16, device="cuda") + lora = merge_lora([lora, lora_bias]) + save_file(lora, "model_coarse_fine_bias.safetensors") + + +def generate_image(lora_path, prompt, seed): + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, + ) + pipe.load_lora(pipe.dit, lora_path) + image = pipe(prompt, seed=seed, height=1024, width=1024, num_inference_steps=50) + return image + + +demo_style() +image = generate_image("model_style.safetensors", "a cat", 0) +image.save("image_1.jpg") + +demo_coarse_fine_bias() +image = generate_image("model_coarse_fine_bias.safetensors", "bowl", 1) +image.save("image_2.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image.py new file mode 100644 index 0000000000000000000000000000000000000000..aae4a22b332493c308735b5f86432edb24dfc6f2 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image.py @@ -0,0 +1,28 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh new file mode 100644 index 0000000000000000000000000000000000000000..e3692236746bea445d39b290e68b5b74bbc8d445 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh @@ -0,0 +1,38 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \ + --data_file_keys "image,blockwise_controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \ + --learning_rate 1e-3 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \ + --trainable_models "blockwise_controlnet" \ + --extra_inputs "blockwise_controlnet_image" \ + --use_gradient_checkpointing \ + --find_unused_parameters + +# If you want to pre-train a Blockwise ControlNet from scratch, +# please run the following script to first generate the initialized model weights file, +# and then start training with a high learning rate (1e-3). + +# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py + +# accelerate launch examples/qwen_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \ +# --data_file_keys "image,blockwise_controlnet_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ +# --model_paths '["models/blockwise_controlnet.safetensors"]' \ +# --learning_rate 1e-3 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." 
\ +# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \ +# --trainable_models "blockwise_controlnet" \ +# --extra_inputs "blockwise_controlnet_image" \ +# --use_gradient_checkpointing \ +# --find_unused_parameters \ No newline at end of file diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh new file mode 100644 index 0000000000000000000000000000000000000000..93313ec5a3d032185bd4f155ffa7a275fc7dcf10 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh @@ -0,0 +1,38 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \ + --data_file_keys "image,blockwise_controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \ + --learning_rate 1e-3 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \ + --trainable_models "blockwise_controlnet" \ + --extra_inputs "blockwise_controlnet_image" \ + --use_gradient_checkpointing \ + --find_unused_parameters + +# If you want to pre-train a Blockwise ControlNet from scratch, +# please run the following script to first generate the initialized model weights file, +# and then start training with a high learning rate (1e-3). + +# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py + +# accelerate launch examples/qwen_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \ +# --data_file_keys "image,blockwise_controlnet_image" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ +# --model_paths '["models/blockwise_controlnet.safetensors"]' \ +# --learning_rate 1e-3 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." 
\ +# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \ +# --trainable_models "blockwise_controlnet" \ +# --extra_inputs "blockwise_controlnet_image" \ +# --use_gradient_checkpointing \ +# --find_unused_parameters \ No newline at end of file diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh new file mode 100644 index 0000000000000000000000000000000000000000..99b25adeeb7de9fa407032239c50e46fb217326f --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh @@ -0,0 +1,38 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \ + --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \ + --learning_rate 1e-3 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \ + --trainable_models "blockwise_controlnet" \ + --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ + --use_gradient_checkpointing \ + --find_unused_parameters + +# If you want to pre-train a Inpaint Blockwise ControlNet from scratch, +# please run the following script to first generate the initialized model weights file, +# and then start training with a high learning rate (1e-3). + +# python examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py + +# accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \ +# --dataset_base_path data/example_image_dataset \ +# --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \ +# --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ +# --max_pixels 1048576 \ +# --dataset_repeat 50 \ +# --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ +# --model_paths '["models/blockwise_controlnet_inpaint.safetensors"]' \ +# --learning_rate 1e-3 \ +# --num_epochs 2 \ +# --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." 
\ +# --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \ +# --trainable_models "blockwise_controlnet" \ +# --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ +# --use_gradient_checkpointing \ +# --find_unused_parameters diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh b/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh new file mode 100644 index 0000000000000000000000000000000000000000..a56fe9d2b409f3dd8e80a064b2dc63f115f32dfc --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Distill-Full_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh b/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh new file mode 100644 index 0000000000000000000000000000000000000000..7fda7b7367b5ba556a83adf645816bf2722d1292 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh @@ -0,0 +1,15 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Edit-2509_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh new file mode 100644 index 0000000000000000000000000000000000000000..ec257654ee3a73834e2be3c78f34907ab9218fb3 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh @@ -0,0 +1,15 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/full/Qwen-Image.sh b/examples/qwen_image/model_training/full/Qwen-Image.sh new file mode 100644 index 0000000000000000000000000000000000000000..979101e62c26e41ec408a0cbf188da5b35f3d50d --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/full/accelerate_config.yaml b/examples/qwen_image/model_training/full/accelerate_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83280f73f315a32eccb065f351d66b4b2678759d --- /dev/null +++ b/examples/qwen_image/model_training/full/accelerate_config.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml b/examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a75f3d91eeae160409650b482e5383ac26b297b --- /dev/null +++ b/examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: 'cpu' + offload_param_device: 'cpu' + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh new file mode 100644 index 0000000000000000000000000000000000000000..226313466dc061481dc3e1ea427ed5f2ef8a63f1 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \ + --data_file_keys "image,blockwise_controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "blockwise_controlnet_image" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh new file mode 100644 index 0000000000000000000000000000000000000000..60d3ca3f8a331369092bec762f7a64ef39d5550d --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \ + --data_file_keys "image,blockwise_controlnet_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "blockwise_controlnet_image" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh new file mode 100644 index 0000000000000000000000000000000000000000..853ffe2667421ffdbc949e3809907e71f0f8f9dc --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \ + --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh new file mode 100644 index 0000000000000000000000000000000000000000..79d7c376de52ffc9268be3eb7eb3ef64f5f9956d --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh @@ -0,0 +1,15 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Distill-Full_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh new file mode 100644 index 0000000000000000000000000000000000000000..061bebb0fc63885c7bcd0e9f8c18c2c7f806a9b3 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh @@ -0,0 +1,24 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_distill_qwen_image.csv \ + --data_file_keys "image" \ + --extra_inputs "seed,rand_device,num_inference_steps,cfg_scale" \ + --height 1328 \ + --width 1328 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Distill-LoRA_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task direct_distill + +# This is an experimental training feature designed to directly distill the model, enabling generation results with fewer steps to approximate those achieved with more steps. +# The model (https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) is trained using this script. +# The sample dataset is provided solely to demonstrate the dataset format. For actual usage, please construct a larger dataset using the base model. 
diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh new file mode 100644 index 0000000000000000000000000000000000000000..7fc0cf971569ec10102aad46db7cd5d86259d492 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh @@ -0,0 +1,18 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit-2509_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh new file mode 100644 index 0000000000000000000000000000000000000000..0662b1e38bdb078fb4ec0fa818ee7846f43dfb2d --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh @@ -0,0 +1,18 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0426714a3ade7af3ad5b2c66d996e966f8f5dd1 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh @@ -0,0 +1,18 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "data/example_image_dataset" \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-EliGen-Poster_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing \ + --find_unused_parameters \ + --lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-EliGen-V2/model.safetensors" diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh new file mode 100644 index 0000000000000000000000000000000000000000..af861e669116df4759a9d146198c69736b732095 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "data/example_image_dataset" \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-EliGen_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh b/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh new file mode 100644 index 0000000000000000000000000000000000000000..d241ad0e33e734580d23d2759e0f1050afe5e1f2 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh @@ -0,0 +1,20 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "data/example_image_dataset" \ + --dataset_metadata_path data/example_image_dataset/metadata_qwenimage_context.csv \ + --data_file_keys "image,context_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-In-Context-Control-Union_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 64 \ + --lora_checkpoint "models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union/model.safetensors" \ + --extra_inputs "context_image" \ + --use_gradient_checkpointing \ + --find_unused_parameters + +# if you want to train from scratch, you can remove the --lora_checkpoint argument diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1198a5f676f7bee45d0726e299e16c8f8a1456b --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -0,0 +1,16 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py b/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0392f589d7739197ba4251bb1c6bfb0fcfe498 --- /dev/null +++ b/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Initialize.py @@ -0,0 +1,13 @@ +# This script is for initializing a Qwen-Image-Blockwise-ControlNet +from diffsynth import hash_state_dict_keys +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet +import torch +from safetensors.torch import save_file + + +controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda") +controlnet.init_weight() +state_dict_controlnet = controlnet.state_dict() + +print(hash_state_dict_keys(state_dict_controlnet)) +save_file(state_dict_controlnet, "models/blockwise_controlnet.safetensors") diff --git a/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py b/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..83111894709eda93fbab925ca1804e2a3fd477dd --- /dev/null +++ b/examples/qwen_image/model_training/scripts/Qwen-Image-Blockwise-ControlNet-Inpaint-Initialize.py @@ -0,0 +1,12 @@ +# This script is for initializing a Inpaint Qwen-Image-ControlNet +import torch +from diffsynth import hash_state_dict_keys +from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from safetensors.torch import save_file + +controlnet = QwenImageBlockWiseControlNet(additional_in_dim=4).to(dtype=torch.bfloat16, device="cuda") +controlnet.init_weight() +state_dict_controlnet 
= controlnet.state_dict() + +print(hash_state_dict_keys(state_dict_controlnet)) +save_file(state_dict_controlnet, "models/blockwise_controlnet_inpaint.safetensors") diff --git a/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh b/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh new file mode 100644 index 0000000000000000000000000000000000000000..19191dde94494869eaba8a466598249d56ef3a28 --- /dev/null +++ b/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh @@ -0,0 +1,40 @@ +# This script is provided as an example only. +# Please manually replace the two datasets: +# the first training dataset should contain content you do not want to generate, +# and the second training dataset should contain content you do want to generate. + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-deterministic" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-LoRA-differencial" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --preset_lora_path "./models/train/Qwen-Image-LoRA-deterministic/epoch-4.safetensors" \ + --preset_lora_model "dit" diff --git a/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh b/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh new file mode 100644 index 0000000000000000000000000000000000000000..133279bdf449ba1f05a9f62b297b1f453c6920e7 --- /dev/null +++ b/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh @@ -0,0 +1,17 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_lora_fp8" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --fp8_models "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" diff --git a/examples/qwen_image/model_training/special/fp8_training/validate.py b/examples/qwen_image/model_training/special/fp8_training/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..60783e5e91e8a2854086390c92acd647b728ea0b --- /dev/null +++ b/examples/qwen_image/model_training/special/fp8_training/validate.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image_lora_fp8/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt, seed=0) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/special/low_vram_training/Qwen-Image-LoRA.sh b/examples/qwen_image/model_training/special/low_vram_training/Qwen-Image-LoRA.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f7ed5fdb0053829bc83594ac5da099425403ca9 --- /dev/null +++ b/examples/qwen_image/model_training/special/low_vram_training/Qwen-Image-LoRA.sh @@ -0,0 +1,38 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset 
\ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --fp8_models "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --use_gradient_checkpointing_offload \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:data_process" + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --fp8_models "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-lowvram" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --use_gradient_checkpointing_offload \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:train" diff --git a/examples/qwen_image/model_training/special/simple/train.py b/examples/qwen_image/model_training/special/simple/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e6a112485474b9a3ff3885a4c1ba89fa746a37 --- /dev/null +++ b/examples/qwen_image/model_training/special/simple/train.py @@ -0,0 +1,76 @@ +import torch, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.diffusion import * + +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + # Load the pipeline + self.pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + ) + # Switch to training mode + self.switch_pipe_to_training_mode( + self.pipe, + lora_base_model="dit", + lora_target_modules="to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj", + lora_rank=32, + ) + + def forward(self, data): + # Preprocess + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. 
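+ # Note: PIL's Image.size is (width, height), so size[1] below is the height and size[0] is the width.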
+ "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": True, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + # Loss + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + +if __name__ == "__main__": + accelerator = accelerate.Accelerator( + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)], + ) + dataset = UnifiedDataset( + base_path="data/example_image_dataset", + metadata_path="data/example_image_dataset/metadata.csv", + repeat=50, + data_file_keys="image", + main_data_operator=UnifiedDataset.default_image_operator( + base_path="data/example_image_dataset", + height=512, + width=512, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule(accelerator.device) + model_logger = ModelLogger( + output_path="models/toy_model", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=1e-5, num_epochs=1, + ) diff --git a/examples/qwen_image/model_training/special/split_training/Qwen-Image-LoRA.sh b/examples/qwen_image/model_training/special/split_training/Qwen-Image-LoRA.sh new file mode 100644 index 0000000000000000000000000000000000000000..84e72671c884513fefc4ac579a94554c40aa759e --- /dev/null +++ b/examples/qwen_image/model_training/special/split_training/Qwen-Image-LoRA.sh @@ -0,0 +1,34 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:data_process" + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:train" diff --git a/examples/qwen_image/model_training/special/split_training/validate.py b/examples/qwen_image/model_training/special/split_training/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f9e07a9a618bbe249e63532ee58d2f74f86393 --- /dev/null +++ b/examples/qwen_image/model_training/special/split_training/validate.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-LoRA-splited/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt, seed=0) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0e4b6a4436091f6e8e2ed0e5670f4f83b67beb --- /dev/null +++ b/examples/qwen_image/model_training/train.py @@ -0,0 +1,146 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, processor_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) + self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + 
task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "edit_image_auto_resize": True, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def qwen_image_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. 
If provided, the processor will be used for image editing.") + return parser + + +if __name__ == "__main__": + parser = qwen_image_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + processor_path=args.processor_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device=accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae4d5bb72444f8d7550dc1856a2a0275b224ed9 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Canny_full/epoch-1.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/image_1.jpg" +) +controlnet_image = 
Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py new file mode 100644 index 0000000000000000000000000000000000000000..18b597e1fa7508127c431aa4a466c86392a83044 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py @@ -0,0 +1,31 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Depth_full/epoch-1.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328)) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..15a15b4df8bd9608f8ce6782a309d37699b923d1 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(path="models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full/epoch-1.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +prompt = "a cat with sunglasses" +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) 
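+# The mask below is resized to match the control image; by the usual inpainting convention, the masked region is the area the model repaints.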
+inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], + height=1024, width=1024, + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000000000000000000000000000000000000..07389c524c2e142679d866257785030eb15795a2 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Distill-Full_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py new file mode 100644 index 0000000000000000000000000000000000000000..9295904cacc2c00bfd0eeb801f2f10343fc89f45 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py @@ -0,0 +1,26 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Edit-2509_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) + +prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." 
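+# Multiple reference images are passed to edit_image as a list; the prompt refers to them in order as Figure 1 and Figure 2.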
+images = [ + Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), + Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +] +image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..c08b4850c4d20925a7d008853d11b213ee3d69b6 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py @@ -0,0 +1,23 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Edit_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) + +prompt = "将裙子改为粉色" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save(f"image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image.py b/examples/qwen_image/model_training/validate_full/Qwen-Image.py new file mode 100644 index 0000000000000000000000000000000000000000..872321825eb4028fe18d5b00c8ee815bc6356e75 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("models/train/Qwen-Image_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt, seed=0) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py new file mode 100644 index 0000000000000000000000000000000000000000..4a54b5ee07316dcca50e439ec1083d392305ce18 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py @@ -0,0 +1,32 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import 
dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Canny_lora/epoch-4.safetensors") + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="canny/image_1.jpg" +) +controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328)) + +prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py new file mode 100644 index 0000000000000000000000000000000000000000..626654559d1970697a2b9a12af3722581956bfbb --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py @@ -0,0 +1,33 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput +from PIL import Image +import torch +from modelscope import dataset_snapshot_download + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Depth_lora/epoch-4.safetensors") + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="depth/image_1.jpg" +) + +controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1328, 1328)) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image)] +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..60bd9f2bbbefd95ab0efe1c7e07aaeb96f4a560d --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -0,0 +1,34 @@ +import torch +from PIL import 
Image +from modelscope import dataset_snapshot_download +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_lora/epoch-4.safetensors") + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_image_dataset", + local_dir="./data/example_image_dataset", + allow_file_pattern="inpaint/*.jpg" +) +prompt = "a cat with sunglasses" +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +image = pipe( + prompt, seed=0, + blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], + height=1024, width=1024, + num_inference_steps=40, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000000000000000000000000000000000000..7f644aa805d47bdc04711b28d5ea0d3d3eafc36e --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-Full_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..d56b9af7f4c6e8bf6e83845c3c2511a2a9374810 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py @@ -0,0 +1,23 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-LoRA_lora/epoch-4.safetensors") +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, + seed=0, + num_inference_steps=4, + cfg_scale=1, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py new file mode 100644 index 0000000000000000000000000000000000000000..e701b0749809f08db27dd52c37dc93445d0f4995 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit-2509_lora/epoch-4.safetensors") + +prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2." +images = [ + Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)), + Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)), +] +image = pipe(prompt, edit_image=images, seed=123, num_inference_steps=40, height=1024, width=1024) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py new file mode 100644 index 0000000000000000000000000000000000000000..2576be339df5fba385c37567af0902842c306289 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py @@ -0,0 +1,21 @@ +import torch +from PIL import Image +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=None, + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Edit_lora/epoch-4.safetensors") + +prompt = "将裙子改为粉色" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, edit_image=image, seed=0, num_inference_steps=40, height=1024, width=1024) +image.save(f"image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py 
b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py new file mode 100644 index 0000000000000000000000000000000000000000..81f697c4073b04d242761e90c5b6159a3a4e3505 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py @@ -0,0 +1,29 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen-Poster_lora/epoch-4.safetensors") + + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + +image = pipe(global_prompt, + seed=0, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks) +image.save("Qwen-Image-EliGen-Poster.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7904e9d844232c599b7534bc58398e1820702b --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py @@ -0,0 +1,29 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/epoch-4.safetensors") + + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + +image = pipe(global_prompt, + seed=0, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks) +image.save("Qwen-Image_EliGen.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py new file mode 100644 index 0000000000000000000000000000000000000000..83a93a3f6bf41529c9350052e03f8f95050c8eb5 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py @@ -0,0 +1,19 @@ 
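+# Validates the In-Context Control Union LoRA: the control image is supplied via context_image and the prompt is prefixed with "Context_Control.".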
+from PIL import Image +import torch +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-In-Context-Control-Union_lora/epoch-4.safetensors") +image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024)) +prompt = "Context_Control. a dog" +image = pipe(prompt=prompt, seed=0, context_image=image, height=1024, width=1024) +image.save("image_context.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image.py new file mode 100644 index 0000000000000000000000000000000000000000..16be2b4bda15c696c2b7bb4f3fd36176db67da0e --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image_lora/epoch-4.safetensors") +prompt = "a dog" +image = pipe(prompt, seed=0) +image.save("image.jpg") diff --git a/examples/wanvideo/model_inference/LongCat-Video.py b/examples/wanvideo/model_inference/LongCat-Video.py new file mode 100644 index 0000000000000000000000000000000000000000..a064e91c04f9d01575af76baf8ea2af148be3d1e --- /dev/null +++ b/examples/wanvideo/model_inference/LongCat-Video.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. 
The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=0, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, +) +save_video(video, "video_1_LongCat-Video.mp4", fps=15, quality=5) + +# Video-continuation (The number of frames in `longcat_video` should be 4n+1.) +longcat_video = video[-17:] +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=1, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, + longcat_video=longcat_video, +) +save_video(video, "video_2_LongCat-Video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Run-Egg-Statemachine-Infer.sh b/examples/wanvideo/model_inference/Run-Egg-Statemachine-Infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c18605fe3160cff021f47e9e988d403d2a8fc82 --- /dev/null +++ b/examples/wanvideo/model_inference/Run-Egg-Statemachine-Infer.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run from repo root (DiffSynth-Studio/) +cd "$(dirname "$0")/../../.." 
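+# All relative paths below resolve against the repo root; the PYTHONPATH export below makes the in-repo diffsynth package importable.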
+ +export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}" + +if [[ "${CONDA_DEFAULT_ENV:-}" == "diffsyn" ]]; then + PYTHON_BIN="python" +else + PYTHON_BIN="conda run -n diffsyn python" +fi + +WAN_MODEL_DIR="${WAN_MODEL_DIR:-/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-1.3B}" +DATASET_DIR="${DATASET_DIR:-examples/wanvideo/model_training/egg_statemachine_dataset}" +CHECKPOINT="${CHECKPOINT:-./models/train/_egg_statemachine_instance/epoch-300.safetensors}" +OUT_DIR="${OUT_DIR:-./output/wan2.1-1.3b-statemachine-egg}" +SIGLIP_MODEL_ID="${SIGLIP_MODEL_ID:-google/siglip-so400m-patch14-384}" +SIGLIP_ORIGIN_PATTERN="${SIGLIP_ORIGIN_PATTERN:-model.safetensors}" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +# Auto-detect H/W/T from dataset if not provided +if [[ -z "${HEIGHT:-}" || -z "${WIDTH:-}" || -z "${NUM_FRAMES:-}" ]]; then + eval "$($PYTHON_BIN - <<'PY' +from pathlib import Path +import json, numpy as np + +base = Path('examples/wanvideo/model_training/egg_statemachine_dataset') +meta = json.load(open(base / 'metadata.json'))[0] +mask = np.load(base / meta['instance_masks']) +if mask.ndim == 5: + _, _, T, H, W = mask.shape +else: + _, T, H, W = mask.shape +print(f'HEIGHT={H} WIDTH={W} NUM_FRAMES={T}') +PY +)" +fi + +$PYTHON_BIN examples/wanvideo/model_inference/Wan2.1-1.3b-statemachine-egg.py \ + --device "cuda" \ + --model_dir "$WAN_MODEL_DIR" \ + --dataset_dir "$DATASET_DIR" \ + --checkpoint "$CHECKPOINT" \ + --output_dir "$OUT_DIR" \ + --height "${HEIGHT:-64}" \ + --width "${WIDTH:-64}" \ + --num_frames "${NUM_FRAMES:-53}" \ + --denoising_strength "${DENOISE:-0.7}" \ + --num_inference_steps "${STEPS:-30}" \ + --seed "${SEED:-0}" \ + --fps "${FPS:-15}" \ + --quality "${QUALITY:-5}" \ + ${USE_SIGLIP_IMAGE_ENCODER:+--use_siglip_image_encoder} \ + --siglip_model_id "$SIGLIP_MODEL_ID" \ + --siglip_origin_pattern "$SIGLIP_ORIGIN_PATTERN" diff --git a/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py b/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..ec23b02df046b7c375556cb6cc9168e567904a26 --- /dev/null +++ b/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py @@ -0,0 +1,49 @@ +import torch +import PIL +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download +from typing import List + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="wanvap/*", local_dir="data/example_video_dataset") +ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4' +target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg' + +def select_frames(video_frames, num): + idx = torch.linspace(0, 
len(video_frames) - 1, num).long().tolist() + return [video_frames[i] for i in idx] + +image = Image.open(target_image_path).convert("RGB") +ref_video = VideoData(ref_video_path, height=480, width=832) +ref_frames = select_frames(ref_video, num=49) + +vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery." +prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent." 
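+# In this example, vap_prompt describes the transformation shown in the reference video, while prompt describes how the same effect should play out on the target image; both share the negative prompt defined below.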
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + input_image=image, + seed=42, tiled=True, + height=480, width=832, + num_frames=49, + vap_video=ref_frames, + vap_prompt=vap_prompt, + negative_vap_prompt=negative_prompt, +) +save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-1.3b-mc-lora.py b/examples/wanvideo/model_inference/Wan2.1-1.3b-mc-lora.py new file mode 100644 index 0000000000000000000000000000000000000000..a9405279a40494ac4b9f9d552c6ca941340f1c97 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-1.3b-mc-lora.py @@ -0,0 +1,118 @@ +import argparse +from pathlib import Path +import re +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +def _repo_root() -> Path: + # DiffSynth-Studio/examples/wanvideo/model_inference/ + return Path(__file__).resolve().parents[3] + +def _read_prompts(path: Path) -> list[str]: + text = path.read_text(encoding="utf-8") + prompts = [line.strip() for line in text.splitlines() if line.strip() and not line.strip().startswith("#")] + if not prompts: + raise ValueError(f"No prompts found in: {path}") + return prompts + +def _safe_name(s: str, max_len: int = 80) -> str: + s = s.strip().lower() + s = re.sub(r"[^a-z0-9]+", "-", s).strip("-") + return (s[:max_len].rstrip("-")) or "prompt" + +def main() -> None: + root_dir = _repo_root() + + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompts_file", + type=Path, + default=Path(__file__).resolve().with_name("mc_prompts_10.txt"), + ) + parser.add_argument( + "--output_dir", + type=Path, + default=(root_dir / "output" / "wan2.1-1.3b-mc-lora"), + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed_base", type=int, default=0) + parser.add_argument("--alpha", type=float, default=1.0) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--num_frames", type=int, default=80) + parser.add_argument("--num_inference_steps", type=int, default=30) + parser.add_argument("--cfg_scale", type=float, default=4.5) + parser.add_argument("--sigma_shift", type=float, default=4.0) + parser.add_argument("--fps", type=int, default=15) + parser.add_argument("--quality", type=int, default=5) + parser.add_argument("--tiled", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--tile_size_h", type=int, default=30) + parser.add_argument("--tile_size_w", type=int, default=52) + parser.add_argument("--tile_stride_h", type=int, default=15) + parser.add_argument("--tile_stride_w", type=int, default=26) + parser.add_argument("--sliding_window_size", type=int, default=32) + parser.add_argument("--sliding_window_stride", type=int, default=32) + parser.add_argument("--motion_bucket_id", type=int, default=None) + parser.add_argument( + "--negative_prompt", + type=str, + 
default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + ) + parser.add_argument( + "--model_dir", + type=Path, + default=(root_dir / "models" / "Wan-AI" / "Wan2.1-T2V-1.3B"), + ) + parser.add_argument( + "--lora_path", + type=Path, + default=(root_dir / "models" / "train" / "Wan2.1-1.3b-mc-lora" / "epoch-1.safetensors"), + ) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + prompts = _read_prompts(args.prompts_file) + + model_dir = args.model_dir + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=args.device, + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + ) + + pipe.load_lora(pipe.dit, str(args.lora_path), alpha=args.alpha) + + for i, prompt in enumerate(prompts, start=1): + seed = args.seed_base + (i - 1) + video = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + seed=seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + cfg_scale=args.cfg_scale, + sigma_shift=args.sigma_shift, + motion_bucket_id=args.motion_bucket_id, + tiled=args.tiled, + tile_size=(args.tile_size_h, args.tile_size_w), + tile_stride=(args.tile_stride_h, args.tile_stride_w), + sliding_window_size=args.sliding_window_size, + sliding_window_stride=args.sliding_window_stride, + ) + base = f"p{i:02d}_seed{seed}_{_safe_name(prompt)}" + mp4_path = args.output_dir / f"{base}.mp4" + txt_path = args.output_dir / f"{base}.txt" + txt_path.write_text(prompt + "\n", encoding="utf-8") + save_video(video, str(mp4_path), fps=args.fps, quality=args.quality) + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py new file mode 100644 index 0000000000000000000000000000000000000000..a6292d920321d23905034c9f96caa7a91c35c245 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + 
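+    # motion_bucket_id controls how much motion the speed-control adapter injects:
+    # this script renders the same prompt twice, with 0 for the slow variant (this call)
+    # and 100 for the fast variant (the second call below).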
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=0 +) +save_video(video, "video_slow_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=100 +) +save_video(video, "video_fast_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-1.3b-statemachine-egg.py b/examples/wanvideo/model_inference/Wan2.1-1.3b-statemachine-egg.py new file mode 100644 index 0000000000000000000000000000000000000000..24697c13b77087f5451aab493609ecdba18d9692 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-1.3b-statemachine-egg.py @@ -0,0 +1,391 @@ +import argparse +import json +import math +from pathlib import Path + +import numpy as np +import torch +from diffsynth.core import UnifiedDataset +from diffsynth.core.data.operators import ImageCropAndResize, LoadAudio, LoadTorchPickle, LoadVideo, ToAbsolutePath +from diffsynth.pipelines.wan_video_statemachine import ModelConfig, WanVideoPipeline +from diffsynth.utils.data import save_video +from safetensors.torch import load_file as load_safetensors + + +def _repo_root() -> Path: + # DiffSynth-Studio/examples/wanvideo/model_inference/ + return Path(__file__).resolve().parents[3] + + +def _infer_wan_dit_kwargs_from_loaded_dit(dit: torch.nn.Module) -> dict: + patch_size = tuple(int(x) for x in dit.patch_size) + num_heads = int(dit.blocks[0].num_heads) + num_layers = int(len(dit.blocks)) + ffn_dim = int(dit.blocks[0].ffn[0].out_features) + text_dim = int(dit.text_embedding[0].in_features) + eps = float(dit.blocks[0].norm1.eps) + out_features = int(dit.head.head.out_features) + out_dim = int(out_features // int(math.prod(patch_size))) + return { + "dim": int(dit.dim), + "in_dim": int(dit.in_dim), + "ffn_dim": ffn_dim, + "out_dim": out_dim, + "text_dim": text_dim, + "freq_dim": int(dit.freq_dim), + "eps": eps, + "patch_size": patch_size, + "num_heads": num_heads, + "num_layers": num_layers, + "has_image_input": bool(dit.has_image_input), + "has_image_pos_emb": bool(getattr(dit, "has_image_pos_emb", False)), + "has_ref_conv": bool(getattr(dit, "has_ref_conv", False)), + "add_control_adapter": bool(getattr(dit, "control_adapter", None) is not None), + "seperated_timestep": bool(getattr(dit, "seperated_timestep", False)), + "require_vae_embedding": bool(getattr(dit, "require_vae_embedding", True)), + "require_clip_embedding": bool(getattr(dit, "require_clip_embedding", True)), + "fuse_vae_embedding_in_latents": bool(getattr(dit, "fuse_vae_embedding_in_latents", False)), + } + + +def _replace_dit_with_instance_statemachine( + pipe: WanVideoPipeline, + num_class_ids: int, + num_state_ids: int, + num_instance_ids: int, +): + from diffsynth.models.wan_video_dit_instance import WanModel as InstanceWanModel + + def convert_one(dit: torch.nn.Module): + if dit is None: + return None + kwargs = _infer_wan_dit_kwargs_from_loaded_dit(dit) + kwargs.update( + num_class_ids=num_class_ids, + num_state_ids=num_state_ids, + num_instance_ids=num_instance_ids, + instance_text_dim=kwargs["text_dim"], + ) + new_dit = 
InstanceWanModel(**kwargs) + load_result = new_dit.load_state_dict(dit.state_dict(), strict=False) + if len(load_result.missing_keys) > 0: + print(f"[statemachine] missing keys (expected for new modules): {len(load_result.missing_keys)}") + if len(load_result.unexpected_keys) > 0: + print(f"[statemachine] unexpected keys: {len(load_result.unexpected_keys)}") + return new_dit.to(device=pipe.device, dtype=pipe.torch_dtype) + + pipe.dit = convert_one(pipe.dit) + pipe.dit2 = convert_one(pipe.dit2) + + +def _load_ids_any(base, rel): + rel = str(rel) + path = Path(base) / rel + if rel.endswith(".json"): + return torch.as_tensor(json.loads(path.read_text()), dtype=torch.long) + return torch.load(path, map_location="cpu", weights_only=False).long() + + +def _load_masks_any(base, rel): + rel = str(rel) + path = Path(base) / rel + if rel.endswith((".npy", ".npz")): + arr = np.load(path) + if isinstance(arr, np.lib.npyio.NpzFile): + if "arr_0" in arr: + arr = arr["arr_0"] + else: + arr = arr[list(arr.files)[0]] + return torch.as_tensor(arr, dtype=torch.float32) + return torch.load(path, map_location="cpu", weights_only=False).float() + + +def _ensure_instance_batch(x: torch.Tensor, kind: str) -> torch.Tensor: + if kind == "masks": + # want (B,N,T,H,W) or (B,N,L) + if x.ndim == 4: + x = x.unsqueeze(0) + elif x.ndim == 2: + x = x.unsqueeze(0) + return x + # ids: (B,N) + if x.ndim == 1: + x = x.unsqueeze(0) + return x + + +def _maybe_load_json_arg(val): + if val is None: + return None + try_path = Path(val) + raw = try_path.read_text() if try_path.exists() else val + try: + return json.loads(raw) + except Exception: + return raw + + +def _normalize_text_list(val): + if val is None: + return None + if isinstance(val, str): + return [v for v in val.split(",") if v] + return val + + +def _masks_to_bboxes(instance_masks: torch.Tensor) -> torch.Tensor: + """ + instance_masks: (B, N, T, H, W) in [0,1] + returns bboxes: (B, N, T, 4) in pixel coords xyxy + """ + if instance_masks.ndim != 5: + raise ValueError(f"instance_masks must be (B,N,T,H,W), got {tuple(instance_masks.shape)}") + b, n, t, h, w = instance_masks.shape + bboxes = torch.zeros((b, n, t, 4), dtype=torch.float32) + masks = instance_masks > 0.5 + for bi in range(b): + for ni in range(n): + for ti in range(t): + ys, xs = torch.where(masks[bi, ni, ti]) + if ys.numel() == 0: + continue + x0 = int(xs.min().item()) + x1 = int(xs.max().item()) + 1 + y0 = int(ys.min().item()) + y1 = int(ys.max().item()) + 1 + bboxes[bi, ni, ti] = torch.tensor([x0, y0, x1, y1], dtype=torch.float32) + return bboxes + + +def main() -> None: + root = _repo_root() + parser = argparse.ArgumentParser(description="Infer with Wan instance-statemachine DiT on the egg dataset.") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--height", type=int, default=64) + parser.add_argument("--width", type=int, default=64) + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--num_inference_steps", type=int, default=30) + parser.add_argument("--fps", type=int, default=15) + parser.add_argument("--quality", type=int, default=5) + parser.add_argument("--denoising_strength", type=float, default=0.7) + parser.add_argument( + "--dataset_dir", + type=Path, + default=(root / "examples" / "wanvideo" / "model_training" / "egg_statemachine_dataset"), + ) + parser.add_argument( + "--model_dir", + type=Path, + default=(root / "models" / "Wan-AI" / "Wan2.1-T2V-1.3B"), + ) + 
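+    # The path defaults below assume the layout used elsewhere in this repo: base weights
+    # under models/Wan-AI/Wan2.1-T2V-1.3B and training outputs under models/train/;
+    # pass --model_dir / --checkpoint to point elsewhere.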
parser.add_argument( + "--checkpoint", + type=Path, + default=(root / "models" / "train" / "_egg_statemachine_instance" / "epoch-0.safetensors"), + help="Train output safetensors (contains only instance modules if you used --train_instance_only).", + ) + parser.add_argument("--output_dir", type=Path, default=(root / "output" / "wan2.1-1.3b-statemachine-egg")) + parser.add_argument("--num_class_ids", type=int, default=200) + parser.add_argument("--num_state_ids", type=int, default=100) + parser.add_argument("--num_instance_ids", type=int, default=1000) + parser.add_argument("--prompt_override", type=str, default=None) + parser.add_argument( + "--reverse_state_progress", + action="store_true", + help="如果想要熟->生的反向状态序列(针对 per-frame state 数据集),会将 state_ids 在帧维上反转。", + ) + parser.add_argument( + "--state_schedule", + type=str, + default="raw_to_cooked", + choices=["raw_to_cooked", "cooked_to_raw", "keep"], + help="指定推理时的状态时间轴:raw_to_cooked(默认) / cooked_to_raw / keep(用数据集原始state_ids)。", + ) + parser.add_argument("--switch_frame", type=int, default=None, help="切换帧;默认 T//2。仅当 schedule 非 keep 时生效。") + parser.add_argument("--raw_state_id", type=int, default=1, help="未熟的 state_id") + parser.add_argument("--cooked_state_id", type=int, default=2, help="已熟的 state_id") + parser.add_argument("--use_siglip_image_encoder", action="store_true", help="使用 SigLIP 视觉编码替代默认 CLIP。") + parser.add_argument("--siglip_model_id", type=str, default="google/siglip-so400m-patch14-384", help="SigLIP 模型 ID。") + parser.add_argument("--siglip_origin_pattern", type=str, default="model.safetensors", help="SigLIP 权重文件匹配模式。") + parser.add_argument("--mask_shift_per_frame", type=int, default=0, help="正值表示 mask 每帧向上平移的像素数(累加)。") + parser.add_argument("--instance_class_text", type=str, default=None, help="JSON 列表或逗号分隔的 instance class 文本。") + parser.add_argument("--instance_state_texts", type=str, default=None, help="JSON 列表的列表或文件路径,用于多状态文本。") + parser.add_argument("--instance_state_weights", type=str, default=None, help="JSON 列表的列表或文件路径,对应多状态权重。") + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + dataset = UnifiedDataset( + base_path=str(args.dataset_dir), + metadata_path=str(args.dataset_dir / "metadata.json"), + repeat=1, + data_file_keys=("video", "instance_class_ids", "instance_state_ids", "instance_ids", "instance_masks"), + main_data_operator=UnifiedDataset.default_video_operator( + base_path=str(args.dataset_dir), + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + num_frames=args.num_frames, + time_division_factor=4, + time_division_remainder=1, + ), + special_operator_map={ + "animate_face_video": ToAbsolutePath(str(args.dataset_dir)) >> LoadVideo( + args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16) + ), + "input_audio": ToAbsolutePath(str(args.dataset_dir)) >> LoadAudio(sr=16000), + "instance_class_ids": lambda p: _load_ids_any(args.dataset_dir, p), + "instance_state_ids": lambda p: _load_ids_any(args.dataset_dir, p), + "instance_ids": lambda p: _load_ids_any(args.dataset_dir, p), + "instance_masks": lambda p: _load_masks_any(args.dataset_dir, p), + }, + ) + sample = dataset[0] + + prompt = sample.get("prompt", "an egg") + if args.prompt_override is not None: + prompt = args.prompt_override + + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=args.device, + model_configs=[ + ModelConfig(path=str(args.model_dir / "diffusion_pytorch_model.safetensors")), + 
ModelConfig(path=str(args.model_dir / "models_t5_umt5-xxl-enc-bf16.pth")), + ModelConfig(path=str(args.model_dir / "Wan2.1_VAE.pth")), + *( + [ModelConfig(model_id=args.siglip_model_id, origin_file_pattern=args.siglip_origin_pattern, model_name="siglip2_image_encoder")] + if args.use_siglip_image_encoder else [] + ), + ], + tokenizer_config=ModelConfig(path=str(args.model_dir / "google" / "umt5-xxl")), + use_siglip_image_encoder=args.use_siglip_image_encoder, + ) + + _replace_dit_with_instance_statemachine( + pipe, + num_class_ids=args.num_class_ids, + num_state_ids=args.num_state_ids, + num_instance_ids=args.num_instance_ids, + ) + + if args.checkpoint.exists(): + sd = load_safetensors(str(args.checkpoint), device="cpu") + load_result = pipe.dit.load_state_dict(sd, strict=False) + print(f"[ckpt] loaded: {args.checkpoint}") + if len(load_result.missing_keys) > 0: + print(f"[ckpt] missing keys: {len(load_result.missing_keys)}") + if len(load_result.unexpected_keys) > 0: + print(f"[ckpt] unexpected keys: {len(load_result.unexpected_keys)}") + else: + print(f"[ckpt] not found, using random instance modules: {args.checkpoint}") + + instance_class_ids = _ensure_instance_batch(torch.as_tensor(sample["instance_class_ids"]).long(), "ids") + instance_state_ids = _ensure_instance_batch(torch.as_tensor(sample["instance_state_ids"]).long(), "ids") + instance_ids = _ensure_instance_batch(torch.as_tensor(sample["instance_ids"]).long(), "ids") + instance_masks = _ensure_instance_batch(torch.as_tensor(sample["instance_masks"]).float(), "masks") + + T = instance_masks.shape[2] if instance_masks.ndim == 5 else instance_masks.shape[1] + + if args.state_schedule != "keep": + switch_f = args.switch_frame if args.switch_frame is not None else T // 2 + switch_f = max(0, min(T, int(switch_f))) + if args.state_schedule == "raw_to_cooked": + seq = [args.raw_state_id] * switch_f + [args.cooked_state_id] * (T - switch_f) + else: # cooked_to_raw + seq = [args.cooked_state_id] * switch_f + [args.raw_state_id] * (T - switch_f) + seq = torch.tensor(seq, dtype=torch.long).unsqueeze(0) + instance_state_ids = seq + # 若是 per-frame diag mask,N==T;否则广播到 N tokens + if instance_masks.ndim == 5 and instance_masks.shape[1] == T: + instance_state_ids = instance_state_ids + else: + instance_state_ids = instance_state_ids.expand(instance_ids.shape[0], instance_ids.shape[1]) + + if args.mask_shift_per_frame != 0 and instance_masks.ndim == 5: + shift = args.mask_shift_per_frame + b, n, t, h, w = instance_masks.shape + shifted = torch.zeros_like(instance_masks) + for i in range(t): + dy = shift * i + src = instance_masks[:, :, i] + if abs(dy) >= h: + continue # 全部移出画面,保持0 + if dy > 0: + # 向上平移:把下方补0 + shifted[:, :, i, : h - dy, :] = src[:, :, dy:, :] + elif dy < 0: + dy = -dy + shifted[:, :, i, dy:, :] = src[:, :, : h - dy, :] + else: + shifted[:, :, i] = src + instance_masks = shifted + + instance_class_text = _normalize_text_list(_maybe_load_json_arg(args.instance_class_text)) + if instance_class_text is None: + instance_class_text = ["egg"] * int(instance_ids.shape[1]) + if isinstance(instance_class_text, list) and len(instance_class_text) == 1 and int(instance_ids.shape[1]) > 1: + instance_class_text = instance_class_text * int(instance_ids.shape[1]) + + instance_state_texts = _maybe_load_json_arg(args.instance_state_texts) + if instance_state_texts is None: + instance_state_texts = [["raw", "cooked"] for _ in range(int(instance_ids.shape[1]))] + + # weights per frame from instance_state_ids (assumes 2 states matching raw/cooked) + 
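+    # When --instance_state_weights is not supplied, the weights are derived below from
+    # instance_state_ids as a per-frame one-hot over [raw, cooked]: e.g. state ids
+    # [1, 1, 2, 2] with raw_state_id=1 and cooked_state_id=2 become
+    # [[1,0], [1,0], [0,1], [0,1]] for that instance, giving shape (B, N, T, 2).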
if args.instance_state_weights is not None: + instance_state_weights = torch.as_tensor(_maybe_load_json_arg(args.instance_state_weights), dtype=torch.float32) + if instance_state_weights.ndim == 3: + instance_state_weights = instance_state_weights.unsqueeze(0) + else: + if len(instance_state_texts[0]) != 2: + raise ValueError("This example auto-builds weights from instance_state_ids and requires exactly 2 states in instance_state_texts.") + b, n, t = instance_state_ids.shape + instance_state_weights = torch.zeros((b, n, t, 2), dtype=torch.float32) + instance_state_weights[..., 0] = (instance_state_ids == args.raw_state_id).to(torch.float32) + instance_state_weights[..., 1] = (instance_state_ids == args.cooked_state_id).to(torch.float32) + + instance_bboxes = _masks_to_bboxes(instance_masks) + + video = pipe( + prompt=prompt, + input_video=sample["video"], + denoising_strength=args.denoising_strength, + seed=args.seed, + height=sample["video"][0].size[1], + width=sample["video"][0].size[0], + num_frames=len(sample["video"]), + num_inference_steps=args.num_inference_steps, + tiled=True, + instance_ids=instance_ids.to(device=args.device), + instance_class_text=instance_class_text, + instance_state_texts=instance_state_texts, + instance_state_weights=instance_state_weights.to(device=args.device), + instance_bboxes=instance_bboxes.to(device=args.device), + ) + + out_path = args.output_dir / "egg_statemachine_infer.mp4" + save_video(video, str(out_path), fps=args.fps, quality=args.quality) + print(f"[ok] wrote: {out_path}") + + +if __name__ == "__main__": + main() + + + + + + +# conda run -n diffsyn python examples/wanvideo/model_inference/Wan2.1-1.3b-statemachine-egg.py \ +# --device cuda \ +# --model_dir models/Wan-AI/Wan2.1-T2V-1.3B \ +# --dataset_dir examples/wanvideo/model_training/egg_statemachine_dataset \ +# --checkpoint models/train/_egg_statemachine_instance/epoch-377.safetensors \ +# --output_dir output/wan2.1-1.3b-statemachine-egg_moveup_long \ +# --height 256 --width 448 --num_frames 85 \ +# --state_schedule cooked_to_raw --switch_frame 20 \ +# --mask_shift_per_frame 10 \ +# --prompt_override "an egg running upward and drifting, cooling back to raw" \ +# --num_inference_steps 28 --denoising_strength 0.6 diff --git a/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9a899e807b8247c1da0c953466e0fd0fad93ad --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + 
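+    # Fetches only the paired first/last frame images used by this first-last-frame (FLF2V) example.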
dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + seed=0, tiled=True, + height=960, width=960, num_frames=33, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..748fba780e4eaf6fdaa886de3f263077ebea4149 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..39324bad0dc5c3ec7007498f025433e0b0b917bb --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", 
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..b8334fddee746bb36aa5c6ee7e566459bd45a5c4 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..fe16080400ecb4b07ceb4ad3efd09d984b8643f0 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import 
Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..7babad1901eb3551dff31735d6aa598a8cc88ed6 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..acfedaea41d6d2546b49b95c8ee1124303d27734 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..1d14badbc6d5aab9d821e4dc6deb29e83bff356b --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, 
ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..019b1dac57343c506d14082e02889201465e8faa --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + 
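+    # Camera control: camera_control_direction sets the pan direction ("Left" here,
+    # "Up" in the second call below) and camera_control_speed sets how fast the virtual
+    # camera moves (0.01 in both calls).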
input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..45421d99a49f4ef9ab056bca257b0aa702c4319b --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a87352c54b9aa23627afa89e2a87b5164c88b3 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py @@ -0,0 +1,36 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
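+        # Four components are loaded for this control model: the DiT diffusion weights,
+        # the umt5-xxl text encoder, the Wan2.1 VAE, and the open-clip XLM-RoBERTa image
+        # encoder; the tokenizer is shared from Wan2.1-T2V-1.3B below.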
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d22cebec198848b26933fd70019dfa3d5db5fc --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py new file mode 100644 index 
0000000000000000000000000000000000000000..3220433cf64cecf3d2f8cb3e8b1b3e4cb054b9c7 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,35 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + height=720, width=1280, +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B-Statemachine-Instance-Infer.sh b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B-Statemachine-Instance-Infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..23916be4baf3e548d6861b2a78ba0023f4f166c9 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B-Statemachine-Instance-Infer.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run from repo root (DiffSynth-Studio/) +cd "$(dirname "$0")/../../.." 
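+# Every setting below is read from the environment with a fallback default
+# (the ${VAR:-default} pattern), so a typical run only overrides what differs,
+# e.g. (illustrative values):
+#   WAN_MODEL_DIR=models/Wan-AI/Wan2.1-T2V-1.3B PROMPT="a cat on the beach" \
+#     bash examples/wanvideo/model_inference/Wan2.1-T2V-1.3B-Statemachine-Instance-Infer.sh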
+ +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +# ---------------------------- +# 1) Base model files (edit these) +# ---------------------------- +WAN_MODEL_DIR="${WAN_MODEL_DIR:-/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-1.3B}" +DIFFUSION_CKPT="${DIFFUSION_CKPT:-$WAN_MODEL_DIR/diffusion_pytorch_model.safetensors}" +TEXT_ENCODER_CKPT="${TEXT_ENCODER_CKPT:-$WAN_MODEL_DIR/models_t5_umt5-xxl-enc-bf16.pth}" +VAE_CKPT="${VAE_CKPT:-$WAN_MODEL_DIR/Wan2.1_VAE.pth}" +TOKENIZER_PATH="${TOKENIZER_PATH:-$WAN_MODEL_DIR/google/umt5-xxl/}" + +# ---------------------------- +# 2) Optional trained checkpoint (instance modules) +# ---------------------------- +# 训练脚本 `train_statemachine_instance.py` 会产出 `epoch-*.safetensors`(只包含 trainable 参数)。 +INSTANCE_CKPT="${INSTANCE_CKPT:-./models/train/Wan2.1-T2V-1.3B_statemachine_instance/epoch-0.safetensors}" + +# ---------------------------- +# 3) Optional instance inputs (torch .pth) +# ---------------------------- +# 这些都不是必须;没提供就走普通 T2V/I2V/V2V 推理(不加实例控制)。 +INSTANCE_CLASS_IDS_PTH="${INSTANCE_CLASS_IDS_PTH:-}" +INSTANCE_STATE_IDS_PTH="${INSTANCE_STATE_IDS_PTH:-}" +INSTANCE_IDS_PTH="${INSTANCE_IDS_PTH:-}" +INSTANCE_MASKS_PTH="${INSTANCE_MASKS_PTH:-}" + +# ---------------------------- +# 4) Prompt & output +# ---------------------------- +PROMPT="${PROMPT:-纪实摄影风格画面,一只活泼的小狗在草地上奔跑。}" +NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}" +OUT_MP4="${OUT_MP4:-./outputs/Wan2.1-T2V-1.3B_statemachine_instance.mp4}" +SEED="${SEED:-0}" + +python - <<'PY' +import os +from pathlib import Path + +import torch + +from diffsynth.core import ModelConfig, load_state_dict +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video_statemachine import WanVideoPipeline + +device = "cuda" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.bfloat16 + +wan_model_dir = os.environ["WAN_MODEL_DIR"] +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch_dtype, + device=device, + model_configs=[ + ModelConfig(path=os.environ["DIFFUSION_CKPT"]), + ModelConfig(path=os.environ["TEXT_ENCODER_CKPT"]), + ModelConfig(path=os.environ["VAE_CKPT"]), + ], + tokenizer_config=ModelConfig(path=os.environ["TOKENIZER_PATH"]), +) + +# 1) Replace base DiT -> instance DiT (text+bbox path) +from diffsynth.models.wan_video_dit_instance import WanModel as InstanceWanModel +from examples.wanvideo.model_training.train_statemachine_instance import _infer_wan_dit_kwargs_from_loaded_dit + +kwargs = _infer_wan_dit_kwargs_from_loaded_dit(pipe.dit) +kwargs.update( + num_class_ids=int(os.environ.get("NUM_CLASS_IDS", "200")), + num_state_ids=int(os.environ.get("NUM_STATE_IDS", "100")), + num_instance_ids=int(os.environ.get("NUM_INSTANCE_IDS", "1000")), +) +new_dit = InstanceWanModel(**kwargs).to(device=device, dtype=torch_dtype) +new_dit.load_state_dict(pipe.dit.state_dict(), strict=False) +pipe.dit = new_dit + +# 2) Load trained instance checkpoint (optional) +instance_ckpt = os.environ.get("INSTANCE_CKPT", "") +if instance_ckpt and Path(instance_ckpt).exists(): + sd = load_state_dict(instance_ckpt, torch_dtype=torch_dtype, device=device) + missing, unexpected = pipe.dit.load_state_dict(sd, strict=False) + print(f"[load] instance_ckpt={instance_ckpt} missing={len(missing)} unexpected={len(unexpected)}") +else: + print(f"[load] skip instance_ckpt (not found): {instance_ckpt}") + + +def load_pth(path: str, dtype: torch.dtype, device: str): + if not path: + return None + t = torch.load(path, map_location="cpu") + if not isinstance(t, torch.Tensor): + t = 
torch.as_tensor(t) + return t.to(device=device, dtype=dtype) + + +def parse_text_list(env_key): + v = os.environ.get(env_key, "") + if not v: + return None + return [s.strip() for s in v.split(",") if s.strip()] + +instance_class_text = parse_text_list("INSTANCE_CLASS_TEXTS") # e.g., "cat,dog" +instance_state_texts = parse_text_list("INSTANCE_STATE_TEXTS") # e.g., "sleeping,running,walking" (states for one instance) + +instance_bboxes = load_pth(os.environ.get("INSTANCE_BBOXES_PTH", ""), torch.float32, device) +instance_ids = load_pth(os.environ.get("INSTANCE_IDS_PTH", ""), torch.long, device) +instance_state_weights = load_pth(os.environ.get("INSTANCE_STATE_WEIGHTS_PTH", ""), torch.float32, device) + +if instance_state_texts is not None and (not isinstance(instance_state_texts, list) or len(instance_state_texts) == 0): + raise ValueError("INSTANCE_STATE_TEXTS must be a non-empty comma-separated list.") +if instance_state_texts is not None and instance_bboxes is None: + raise ValueError("INSTANCE_BBOXES_PTH is required when using instance control.") +if instance_state_texts is not None and instance_ids is None: + raise ValueError("INSTANCE_IDS_PTH is required when using instance control.") + +if instance_state_texts is not None: + # Wrap as nested list: (N=1, S) + instance_state_texts = [instance_state_texts] + if instance_state_weights is None: + # Default: always use state 0 for all frames. + F = int(instance_bboxes.shape[2]) if instance_bboxes.ndim == 4 else 1 + S = int(len(instance_state_texts[0])) + instance_state_weights = torch.zeros((1, 1, F, S), device=device, dtype=torch.float32) + instance_state_weights[:, :, :, 0] = 1.0 + +video = pipe( + prompt=os.environ["PROMPT"], + negative_prompt=os.environ.get("NEGATIVE_PROMPT", ""), + seed=int(os.environ.get("SEED", "0")), + tiled=True, + instance_class_text=instance_class_text, + instance_state_texts=instance_state_texts, + instance_state_weights=instance_state_weights, + instance_bboxes=instance_bboxes, + instance_ids=instance_ids, +) + +out_mp4 = Path(os.environ["OUT_MP4"]) +out_mp4.parent.mkdir(parents=True, exist_ok=True) +save_video(video, str(out_mp4), fps=15, quality=5) +print(f"[ok] saved: {out_mp4}") +PY diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000000000000000000000000000000000000..747fe4e1896165f63807df4306e5521b13790cf9 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py @@ -0,0 +1,34 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_1_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) + +# Video-to-video +video = VideoData("video_1_Wan2.1-T2V-1.3B.mp4", height=480, width=832) +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_video=video, denoising_strength=0.7, + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..018571b84b10c2a0b22e3a54869a306fb83c75d0 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py new file mode 100644 index 0000000000000000000000000000000000000000..aed089d6190ecdd645c19b4dafc1731464e6e3a7 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py @@ -0,0 +1,52 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + 
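+    # Only the two assets used by the examples below are fetched: a depth control video
+    # and a reference image of the two cats.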
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py new file mode 100644 index 0000000000000000000000000000000000000000..6e46d984e0d2902234ab7b33667fd44e2ec0b6d8 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py @@ -0,0 +1,53 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) 
+save_video(video, "video_2_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py b/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f8474b90c662eb6c8e882700228061694efb53 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-14B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-VACE-14B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..d435b688ffc760f8606ed9380952f6866d31cde6 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py @@ -0,0 +1,62 @@ +import torch +from PIL import Image +from diffsynth.core import 
load_state_dict +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download, snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/wan/animate/*", +) + +# Animate +input_image = Image.open("data/examples/wan/animate/animate_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5) + +# Replace +snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B") +lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.bfloat16, device="cuda")["state_dict"] +pipe.load_lora(pipe.dit, state_dict=lora_state_dict) +input_image = Image.open("data/examples/wan/animate/replace_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4] +animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4] +animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + animate_inpaint_video=animate_inpaint_video, + animate_mask_video=animate_mask_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_2_Wan2.2-Animate-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..58e3b426dca4126a8d8b0c9ed90e73c0e290d59e --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py @@ -0,0 +1,43 @@ +import torch +from diffsynth.utils.data import save_video,VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + 
torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..e9505804e5ca15b46fecb630f0c715afc0642334 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.utils.data import save_video,VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", 
height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-Fun-A14B-Control.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..4458de1dbc5e858d4cd44e21f4ebb5d3b2181b0a --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
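+    # Hypothetical illustration only (this end-frame file is not part of the example):
+    # end_image=Image.open("data/examples/wan/input_image_end.jpg"),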
+) +save_video(video, "video_Wan2.2-Fun-A14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..b95d6112294ba3176c33d3cd5f964afcf02b2293 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) + +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + switch_DiT_boundary=0.9, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..140c6a0d6d062040c275f45e7d81e3153912d0c9 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py @@ -0,0 +1,73 @@ +# This script can generate a single video clip. +# If you need generate long videos, please refer to `Wan2.2-S2V-14B_multi_clips.py`. 
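+# A minimal sketch, assuming you want the clip length to roughly match the audio:
+# the frame count must be of the form 4n+1 and the output is saved at 16 fps, so a
+# suitable value can be derived like this (not part of the original example):
+#
+#     def frames_for_audio(duration_s, fps=16):
+#         n = max(1, int(duration_s * fps) // 4)
+#         return 4 * n + 1  # e.g. 5 s of audio -> 81 frames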
+import torch +from PIL import Image +import librosa +from diffsynth.utils.data import VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), +) +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*" +) + +num_frames = 81 # 4n+1 +height = 448 +width = 832 + +prompt = "a person is singing" +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) +# s2v audio input, recommend 16kHz sampling rate +audio_path = 'data/example_video_dataset/wans2v/sing.MP3' +input_audio, sample_rate = librosa.load(audio_path, sr=16000) + +# Speech-to-video +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + num_inference_steps=40, +) +save_video_with_audio(video[1:], "video_1_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5) + +# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. 
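+# Concretely, the pose video below is loaded at the same height=448, width=832 as
+# `input_image`, and `save_video_with_audio` writes the result at fps=16.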
+pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4' +pose_video = VideoData(pose_video_path, height=height, width=width) + +# Speech-to-video with pose +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + s2v_pose_video=pose_video, + num_inference_steps=40, +) +save_video_with_audio(video[1:], "video_2_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py new file mode 100644 index 0000000000000000000000000000000000000000..35d42bad4e60fb2c49b9bd1bf687909fdd41e962 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py @@ -0,0 +1,117 @@ +import torch +from PIL import Image +import librosa +from diffsynth.utils.data import VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V +from modelscope import dataset_snapshot_download + + +def speech_to_video( + prompt, + input_image, + audio_path, + negative_prompt="", + num_clip=None, + audio_sample_rate=16000, + pose_video_path=None, + infer_frames=80, + height=448, + width=832, + num_inference_steps=40, + fps=16, # recommend fixing fps as 16 for s2v + motion_frames=73, # hyperparameter of wan2.2-s2v + save_path=None, +): + # s2v audio input, recommend 16kHz sampling rate + input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate) + # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. 
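+    # Each generated clip reuses the trailing `motion_frames` frames of the previous
+    # clip as motion context (see `motion_videos` below), keeping consecutive clips
+    # temporally consistent when they are concatenated.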
+ pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None + + audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( + pipe=pipe, + input_audio=input_audio, + audio_sample_rate=sample_rate, + s2v_pose_video=pose_video, + num_frames=infer_frames + 1, + height=height, + width=width, + fps=fps, + ) + num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat + print(f"Generating {num_repeat} video clips...") + motion_videos = [] + video = [] + for r in range(num_repeat): + s2v_pose_latents = pose_latents[r] if pose_latents is not None else None + current_clip = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=infer_frames + 1, + height=height, + width=width, + audio_embeds=audio_embeds[r], + s2v_pose_latents=s2v_pose_latents, + motion_video=motion_videos, + num_inference_steps=num_inference_steps, + ) + current_clip = current_clip[-infer_frames:] + if r == 0: + current_clip = current_clip[3:] + overlap_frames_num = min(motion_frames, len(current_clip)) + motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:] + video.extend(current_clip) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) + print(f"processed the {r+1}th clip of total {num_repeat} clips.") + return video + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*", +) + +infer_frames = 80 # 4n +height = 448 +width = 832 + +prompt = "a person is singing" +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) + +video_with_audio = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_full_Wan2.2-S2V-14B.mp4", + infer_frames=infer_frames, + height=height, + width=width, +) +# num_clip means generating only the first n clips with n * infer_frames frames. 
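+# For example, num_clip=2 with infer_frames=80 yields roughly 2 * 80 = 160 frames
+# (about 10 s at 16 fps) instead of covering the whole audio track.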
+video_with_audio_pose = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_clip_2_Wan2.2-S2V-14B.mp4", + num_clip=2 +) diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..9bef43217574f751c01cb4c4de5e602bc113a3de --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000000000000000000000000000000000000..ca968d599e42f66e94bb0e66ca925f45d9514a14 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -0,0 +1,43 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + num_frames=121, +) +save_video(video, "video_1_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + 
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video_2_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..10566fa1a747be3941614f259857c9bdfe0d37f6 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py @@ -0,0 +1,68 @@ +# Without VRAM Management, 80G VRAM is not enough to run this example. +# We recommend to use `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`. +# CPU Offload is enabled in this example. +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + 
vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/comp_attn_pipeline.py b/examples/wanvideo/model_inference/comp_attn_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f434697c144665c81c15d24f87caea327fb8d6da --- /dev/null +++ b/examples/wanvideo/model_inference/comp_attn_pipeline.py @@ -0,0 +1,251 @@ +""" +Comp-Attn Pipeline 示例 +======================== + +演示如何使用 Comp-Attn 生成多主体组合视频。 + +Prompt 和 BBox 的绑定机制: +========================= +使用 Python 变量拼接 prompt,变量顺序与 bboxes 列表顺序对应: + + subject0 = "red car" + subject1 = "blue bicycle" + + prompt = f"A {subject0} drives left, a {subject1} rides right" + + bboxes = [ + car_bboxes, # 对应 subject0 + bike_bboxes, # 对应 subject1 + ] + + subjects = [subject0, subject1] # 顺序与 bboxes 一致 +""" + +import torch +from PIL import Image, ImageDraw, ImageFont + +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video_comp_attn import WanVideoCompAttnPipeline +from diffsynth.pipelines.wan_video import ModelConfig +from diffsynth.models.comp_attn_model import CompAttnConfig + +torch.cuda.set_device(1) + +# ============================================================================= +# 视频参数 +# ============================================================================= +height, width, num_frames = 480, 832, 81 + +# ============================================================================= +# 定义主体(变量顺序决定与 bboxes 的绑定关系) +# ============================================================================= +subject0 = "blue man" +subject1 = "red woman" +subject2 = "blue woman" +subject3 = "red man" + +# subjects 列表,顺序与 bboxes 一一对应 +subjects = [subject0, subject1, subject2, subject3] + +# ============================================================================= +# Per-frame state control (running -> idle). 
Shape: (M, F, S) +# ============================================================================= +state_texts = [ + ["running", "idle"], + ["running", "idle"], + ["moving", "stopped"], + ["running", "idle"], +] +half = num_frames // 2 +state_weights = [] +for _ in subjects: + weights = [] + for t in range(num_frames): + if t < half: + weights.append([1.0, 0.0]) + else: + weights.append([0.0, 1.0]) + state_weights.append(weights) + +# ============================================================================= +# 运动轨迹定义 - 4 个关键帧 +# ============================================================================= +def create_moving_bbox(start_x, end_x, y_center, box_width, box_height, num_keyframes=80): + """创建从 start_x 移动到 end_x 的关键帧 bbox 序列""" + keyframes = [] + for i in range(num_keyframes): + progress = i / (num_keyframes - 1) + center_x = start_x + (end_x - start_x) * progress + left = center_x - box_width / 2 + right = center_x + box_width / 2 + top = y_center - box_height / 2 + bottom = y_center + box_height / 2 + keyframes.append((left, top, right, bottom)) + return keyframes + +# subject0 (red car) 的运动轨迹:从左往右 +bbox0 = create_moving_bbox( + start_x=50, end_x=350, + y_center=height * 0.60, + box_width=100, + box_height=60, +) + +# subject1 (blue bicycle) 的运动轨迹:从中间往右 +bbox1 = create_moving_bbox( + start_x=300, end_x=600, + y_center=height * 0.75, + box_width=50, + box_height=80, +) + +# subject2 (yellow bus) 的运动轨迹:从右往左(反向) +bbox2 = create_moving_bbox( + start_x=750, end_x=450, + y_center=height * 0.55, + box_width=140, + box_height=80, +) + +# subject3 (green motorcycle) 的运动轨迹:从左下往右上 +bbox3 = create_moving_bbox( + start_x=100, end_x=700, + y_center=height * 0.80, + box_width=60, + box_height=50, +) + +# bboxes 列表,顺序与 subjects 一一对应 +bboxes = [bbox0, bbox1, bbox2, bbox3] + +# ============================================================================= +# 使用变量拼接 Prompt +# ============================================================================= +prompt = ( + f"A {subject0} walks forward, " + f"a {subject1} walks alongside, " + f"a {subject2} passes by in the opposite direction, " + f"and a {subject3} walks through, " + f"busy daytime urban street scene." 
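+    # NOTE: the f-string above interpolates subject0..subject3 in the same order as the
+    # `subjects` and `bboxes` lists defined earlier, which is what binds each phrase to
+    # its trajectory.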
+) + +print("=" * 60) +print("变量绑定关系(4个物体):") +print("=" * 60) +for i, (subj, bbox) in enumerate(zip(subjects, bboxes)): + start_x = (bbox[0][0] + bbox[0][2]) / 2 + end_x = (bbox[-1][0] + bbox[-1][2]) / 2 + direction = "→" if end_x > start_x else "←" + print(f" subject{i} = \"{subj}\" -> bbox{i} ({direction} 移动)") +print(f"\nPrompt:\n {prompt}") +print("=" * 60) + +# ============================================================================= +# 可视化 +# ============================================================================= +def draw_trajectory_visualization(width, height, subjects, trajectories, colors=None): + if colors is None: + colors = ["red", "blue", "orange", "green", "purple", "cyan"] + + img = Image.new("RGB", (width, height), color="white") + draw = ImageDraw.Draw(img) + + grid_spacing = 50 + for x in range(0, width, grid_spacing): + draw.line([(x, 0), (x, height)], fill="lightgray", width=1) + for y in range(0, height, grid_spacing): + draw.line([(0, y), (width, y)], fill="lightgray", width=1) + + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12) + title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 18) + except: + font = ImageFont.load_default() + title_font = font + + for subj_idx, (subject, trajectory) in enumerate(zip(subjects, trajectories)): + color = colors[subj_idx % len(colors)] + + centers = [(b[0] + b[2]) / 2 for b in trajectory], [(b[1] + b[3]) / 2 for b in trajectory] + centers = list(zip(centers[0], centers[1])) + + # 绘制轨迹线 + for i in range(len(centers) - 1): + draw.line([centers[i], centers[i + 1]], fill=color, width=2) + + # 只显示首尾 bbox(避免太密集) + for kf_idx in [0, len(trajectory) - 1]: + bbox = trajectory[kf_idx] + x1, y1, x2, y2 = bbox + line_width = 3 if kf_idx == len(trajectory) - 1 else 2 + draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width) + + draw.text((centers[0][0] - 15, centers[0][1] - 20), "S", fill=color, font=font) + draw.text((centers[-1][0] - 10, centers[-1][1] + 5), "E", fill=color, font=font) + + # 显示变量名 + label = f"{subj_idx}: {subject}" + draw.text((trajectory[0][0], trajectory[0][1] - 30), label, fill=color, font=font) + + draw.text((10, 10), f"Comp-Attn Trajectory ({width}x{height})", fill="black", font=title_font) + + legend_y = 40 + for i, subject in enumerate(subjects): + color = colors[i % len(colors)] + draw.rectangle([10, legend_y + i * 20, 25, legend_y + i * 20 + 15], fill=color) + draw.text((30, legend_y + i * 20), f"subject{i}: {subject}", fill="black", font=font) + + return img + +trajectory_viz = draw_trajectory_visualization(width, height, subjects, bboxes) +trajectory_viz.save("comp_attn_trajectory.png") +print("\n轨迹可视化已保存到 comp_attn_trajectory.png") + +# ============================================================================= +# 创建 Pipeline +# ============================================================================= +import glob +model_dir = "/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-14B" +pipe = WanVideoCompAttnPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda:0", + model_configs=[ + ModelConfig(path=sorted(glob.glob(f"{model_dir}/diffusion_pytorch_model*.safetensors"))), + ModelConfig(path=f"{model_dir}/models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(path=f"{model_dir}/Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(path=f"{model_dir}/google/umt5-xxl/"), +) + +# ============================================================================= +# Comp-Attn 配置 +# 
============================================================================= +comp_attn = CompAttnConfig( + subjects=subjects, # [subject0, subject1, subject2, subject3] + bboxes=bboxes, # [bbox0, bbox1, bbox2, bbox3] + enable_sci=True, + enable_lam=True, + interpolate=False, # 已经是80帧,不需要插值 + temperature=0.2, + state_texts=state_texts, + state_weights=state_weights, + state_scale=1.0, +) + +# ============================================================================= +# 生成视频 +# ============================================================================= +print("\n开始生成视频...") +video = pipe( + prompt=prompt, + negative_prompt="low quality, blurry, distorted, duplicate subjects, static", + seed=42, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + comp_attn=comp_attn, +) + +save_video(video, "video_comp_attn_pipeline.mp4", fps=15, quality=5) +print("视频已保存到 video_comp_attn_pipeline.mp4") diff --git a/examples/wanvideo/model_inference/comp_attn_standalone.py b/examples/wanvideo/model_inference/comp_attn_standalone.py new file mode 100644 index 0000000000000000000000000000000000000000..c547db4a527dbb15f4aba85f87f33f097e2b43be --- /dev/null +++ b/examples/wanvideo/model_inference/comp_attn_standalone.py @@ -0,0 +1,49 @@ +import torch + +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from diffsynth_ext.comp_attn import CompAttnConfig, CompAttnPipelineWrapper + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +pipe = CompAttnPipelineWrapper(pipe) + +height, width, num_frames = 480, 832, 81 +subjects = ["red sedan", "blue bicycle"] + +left_box = (0.05 * width, 0.40 * height, 0.45 * width, 0.85 * height) +right_box = (0.55 * width, 0.40 * height, 0.95 * width, 0.85 * height) + +comp_attn = CompAttnConfig( + subjects=subjects, + bboxes=[ + [left_box] * 4, + [right_box] * 4, + ], + enable_sci=True, + enable_lam=True, + interpolate=True, +) + +video = pipe( + prompt="A red sedan drives on the left while a blue bicycle follows on the right, daytime street scene.", + negative_prompt="low quality, blurry, distorted, duplicate subjects", + seed=0, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + comp_attn=comp_attn, +) + +save_video(video, "video_comp_attn.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/example.py b/examples/wanvideo/model_inference/example.py new file mode 100644 index 0000000000000000000000000000000000000000..719ae168e3e2ffab2b5c93532e83d07108c7cd3c --- /dev/null +++ b/examples/wanvideo/model_inference/example.py @@ -0,0 +1,32 @@ +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = 
WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_boat_seagull.json b/examples/wanvideo/model_inference/example_boat_seagull.json new file mode 100644 index 0000000000000000000000000000000000000000..69c434306849ea622efef697660ff30d056b7ede --- /dev/null +++ b/examples/wanvideo/model_inference/example_boat_seagull.json @@ -0,0 +1 @@ +[[[60, 240, 360, 440], [600, 60, 680, 120]], [[80, 240, 380, 440], [577, 59, 657, 119]], [[100, 240, 400, 440], [554, 58, 634, 118]], [[120, 240, 420, 440], [531, 57, 611, 117]], [[140, 240, 440, 440], [508, 56, 588, 116]], [[160, 240, 460, 440], [485, 55, 565, 115]], [[180, 240, 480, 440], [462, 54, 542, 114]], [[200, 240, 500, 440], [439, 53, 519, 113]], [[220, 240, 520, 440], [416, 52, 496, 112]], [[240, 240, 540, 440], [393, 51, 473, 111]], [[260, 240, 560, 440], [370, 50, 450, 110]], [[280, 240, 580, 440], [347, 49, 427, 109]], [[300, 240, 600, 440], [324, 48, 404, 108]], [[320, 240, 620, 440], [301, 47, 381, 107]], [[340, 240, 640, 440], [278, 46, 358, 106]], [[360, 240, 660, 440], [255, 45, 335, 105]], [[380, 240, 680, 440], [232, 44, 312, 104]], [[400, 240, 700, 440], [209, 43, 289, 103]], [[420, 240, 720, 440], [186, 42, 266, 102]], [[440, 240, 740, 440], [163, 41, 243, 101]], [[460, 240, 760, 440], [140, 40, 220, 100]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_deer_approach.json b/examples/wanvideo/model_inference/example_deer_approach.json new file mode 100644 index 0000000000000000000000000000000000000000..38f7a58b35ed75a1e5d826a1a02c610587c30218 --- /dev/null +++ b/examples/wanvideo/model_inference/example_deer_approach.json @@ -0,0 +1 @@ +[[[340, 220, 420, 320]], [[335, 216, 425, 326]], [[330, 212, 430, 332]], [[325, 208, 435, 338]], [[320, 204, 440, 344]], [[315, 200, 445, 350]], [[310, 196, 450, 356]], [[305, 192, 455, 362]], [[300, 188, 460, 368]], [[295, 184, 465, 374]], [[290, 180, 470, 380]], [[285, 176, 475, 386]], [[280, 172, 480, 392]], [[275, 168, 485, 398]], [[270, 164, 490, 404]], [[265, 160, 495, 410]], [[260, 156, 500, 416]], [[255, 152, 505, 422]], [[250, 148, 510, 428]], [[245, 144, 515, 434]], [[240, 140, 520, 440]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_four_moving_bbox.json b/examples/wanvideo/model_inference/example_four_moving_bbox.json new file mode 100644 index 0000000000000000000000000000000000000000..6a194d1aa343ca2161eee527df3e68fee9b5befb --- /dev/null +++ 
b/examples/wanvideo/model_inference/example_four_moving_bbox.json @@ -0,0 +1 @@ +[[[30, 100, 210, 450], [230, 100, 410, 450], [430, 100, 610, 450], [630, 100, 810, 450]], [[31, 100, 211, 450], [231, 100, 411, 450], [431, 100, 611, 450], [631, 100, 811, 450]], [[32, 100, 212, 450], [232, 100, 412, 450], [432, 100, 612, 450], [632, 100, 812, 450]], [[33, 100, 213, 450], [233, 100, 413, 450], [433, 100, 613, 450], [633, 100, 813, 450]], [[34, 100, 214, 450], [234, 100, 414, 450], [434, 100, 614, 450], [634, 100, 814, 450]], [[35, 100, 215, 450], [235, 100, 415, 450], [435, 100, 615, 450], [635, 100, 815, 450]], [[36, 100, 216, 450], [236, 100, 416, 450], [436, 100, 616, 450], [636, 100, 816, 450]], [[37, 100, 217, 450], [237, 100, 417, 450], [437, 100, 617, 450], [637, 100, 817, 450]], [[38, 100, 218, 450], [238, 100, 418, 450], [438, 100, 618, 450], [638, 100, 818, 450]], [[39, 100, 219, 450], [239, 100, 419, 450], [439, 100, 619, 450], [639, 100, 819, 450]], [[40, 100, 220, 450], [240, 100, 420, 450], [440, 100, 620, 450], [640, 100, 820, 450]], [[41, 100, 221, 450], [241, 100, 421, 450], [441, 100, 621, 450], [641, 100, 821, 450]], [[42, 100, 222, 450], [242, 100, 422, 450], [442, 100, 622, 450], [642, 100, 822, 450]], [[43, 100, 223, 450], [243, 100, 423, 450], [443, 100, 623, 450], [643, 100, 823, 450]], [[44, 100, 224, 450], [244, 100, 424, 450], [444, 100, 624, 450], [644, 100, 824, 450]], [[45, 100, 225, 450], [245, 100, 425, 450], [445, 100, 625, 450], [645, 100, 825, 450]], [[46, 100, 226, 450], [246, 100, 426, 450], [446, 100, 626, 450], [646, 100, 826, 450]], [[47, 100, 227, 450], [247, 100, 427, 450], [447, 100, 627, 450], [647, 100, 827, 450]], [[48, 100, 228, 450], [248, 100, 428, 450], [448, 100, 628, 450], [648, 100, 828, 450]], [[49, 100, 229, 450], [249, 100, 429, 450], [449, 100, 629, 450], [649, 100, 829, 450]], [[50, 100, 230, 450], [250, 100, 430, 450], [450, 100, 630, 450], [650, 100, 830, 450]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_four_pigeons_orbit.json b/examples/wanvideo/model_inference/example_four_pigeons_orbit.json new file mode 100644 index 0000000000000000000000000000000000000000..47f0a4c0dc2211fcd8f6c14f342e12dd622a8ada --- /dev/null +++ b/examples/wanvideo/model_inference/example_four_pigeons_orbit.json @@ -0,0 +1 @@ +[[[160, 80, 240, 160], [560, 80, 640, 160], [560, 320, 640, 400], [160, 320, 240, 400]], [[180, 80, 260, 160], [560, 92, 640, 172], [540, 320, 620, 400], [160, 308, 240, 388]], [[200, 80, 280, 160], [560, 104, 640, 184], [520, 320, 600, 400], [160, 296, 240, 376]], [[220, 80, 300, 160], [560, 116, 640, 196], [500, 320, 580, 400], [160, 284, 240, 364]], [[240, 80, 320, 160], [560, 128, 640, 208], [480, 320, 560, 400], [160, 272, 240, 352]], [[260, 80, 340, 160], [560, 140, 640, 220], [460, 320, 540, 400], [160, 260, 240, 340]], [[280, 80, 360, 160], [560, 152, 640, 232], [440, 320, 520, 400], [160, 248, 240, 328]], [[300, 80, 380, 160], [560, 164, 640, 244], [420, 320, 500, 400], [160, 236, 240, 316]], [[320, 80, 400, 160], [560, 176, 640, 256], [400, 320, 480, 400], [160, 224, 240, 304]], [[340, 80, 420, 160], [560, 188, 640, 268], [380, 320, 460, 400], [160, 212, 240, 292]], [[360, 80, 440, 160], [560, 200, 640, 280], [360, 320, 440, 400], [160, 200, 240, 280]], [[380, 80, 460, 160], [560, 212, 640, 292], [340, 320, 420, 400], [160, 188, 240, 268]], [[400, 80, 480, 160], [560, 224, 640, 304], [320, 320, 400, 400], [160, 176, 240, 256]], [[420, 80, 500, 160], [560, 236, 640, 316], [300, 320, 380, 
400], [160, 164, 240, 244]], [[440, 80, 520, 160], [560, 248, 640, 328], [280, 320, 360, 400], [160, 152, 240, 232]], [[460, 80, 540, 160], [560, 260, 640, 340], [260, 320, 340, 400], [160, 140, 240, 220]], [[480, 80, 560, 160], [560, 272, 640, 352], [240, 320, 320, 400], [160, 128, 240, 208]], [[500, 80, 580, 160], [560, 284, 640, 364], [220, 320, 300, 400], [160, 116, 240, 196]], [[520, 80, 600, 160], [560, 296, 640, 376], [200, 320, 280, 400], [160, 104, 240, 184]], [[540, 80, 620, 160], [560, 308, 640, 388], [180, 320, 260, 400], [160, 92, 240, 172]], [[560, 80, 640, 160], [560, 320, 640, 400], [160, 320, 240, 400], [160, 80, 240, 160]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_moving_bbox.json b/examples/wanvideo/model_inference/example_moving_bbox.json new file mode 100644 index 0000000000000000000000000000000000000000..99de91c8217e80c7182617e0174e8ea406e2ea91 --- /dev/null +++ b/examples/wanvideo/model_inference/example_moving_bbox.json @@ -0,0 +1,24 @@ +[ + [{"x0": 100, "y0": 150, "x1": 350, "y1": 400}], + [{"x0": 120, "y0": 145, "x1": 370, "y1": 395}], + [{"x0": 140, "y0": 140, "x1": 390, "y1": 390}], + [{"x0": 160, "y0": 135, "x1": 410, "y1": 385}], + [{"x0": 180, "y0": 130, "x1": 430, "y1": 380}], + [{"x0": 200, "y0": 125, "x1": 450, "y1": 375}], + [{"x0": 220, "y0": 120, "x1": 470, "y1": 370}], + [{"x0": 240, "y0": 115, "x1": 490, "y1": 365}], + [{"x0": 260, "y0": 110, "x1": 510, "y1": 360}], + [{"x0": 280, "y0": 105, "x1": 530, "y1": 355}], + [{"x0": 300, "y0": 100, "x1": 550, "y1": 350}], + [{"x0": 320, "y0": 105, "x1": 570, "y1": 355}], + [{"x0": 340, "y0": 110, "x1": 590, "y1": 360}], + [{"x0": 360, "y0": 115, "x1": 610, "y1": 365}], + [{"x0": 380, "y0": 120, "x1": 630, "y1": 370}], + [{"x0": 400, "y0": 125, "x1": 650, "y1": 375}], + [{"x0": 420, "y0": 130, "x1": 670, "y1": 380}], + [{"x0": 440, "y0": 135, "x1": 690, "y1": 385}], + [{"x0": 460, "y0": 140, "x1": 710, "y1": 390}], + [{"x0": 480, "y0": 145, "x1": 730, "y1": 395}], + [{"x0": 500, "y0": 150, "x1": 750, "y1": 400}] +] + diff --git a/examples/wanvideo/model_inference/example_multi_moving_bbox.json b/examples/wanvideo/model_inference/example_multi_moving_bbox.json new file mode 100644 index 0000000000000000000000000000000000000000..1debcb1ee87e7556783494e0901491714cd2b98b --- /dev/null +++ b/examples/wanvideo/model_inference/example_multi_moving_bbox.json @@ -0,0 +1 @@ +[[[40, 240, 260, 440], [300, 220, 500, 420], [560, 151, 680, 231]], [[58, 240, 278, 440], [310, 220, 510, 420], [566, 153, 686, 233]], [[76, 240, 296, 440], [320, 220, 520, 420], [572, 154, 692, 234]], [[94, 240, 314, 440], [330, 220, 530, 420], [578, 156, 698, 236]], [[112, 240, 332, 440], [340, 220, 540, 420], [584, 157, 704, 237]], [[130, 240, 350, 440], [350, 220, 550, 420], [590, 158, 710, 238]], [[148, 240, 368, 440], [360, 220, 560, 420], [596, 159, 716, 239]], [[166, 240, 386, 440], [370, 220, 570, 420], [602, 159, 722, 239]], [[184, 240, 404, 440], [380, 220, 580, 420], [608, 160, 728, 240]], [[202, 240, 422, 440], [390, 220, 590, 420], [614, 160, 734, 240]], [[220, 240, 440, 440], [400, 220, 600, 420], [620, 160, 740, 240]], [[238, 240, 458, 440], [410, 220, 610, 420], [626, 160, 746, 240]], [[256, 240, 476, 440], [420, 220, 620, 420], [632, 160, 752, 240]], [[274, 240, 494, 440], [430, 220, 630, 420], [638, 160, 758, 240]], [[292, 240, 512, 440], [440, 220, 640, 420], [644, 159, 764, 239]], [[310, 240, 530, 440], [450, 220, 650, 420], [650, 159, 770, 239]], [[328, 240, 548, 440], [460, 220, 660, 
420], [656, 158, 776, 238]], [[346, 240, 566, 440], [470, 220, 670, 420], [662, 157, 782, 237]], [[364, 240, 584, 440], [480, 220, 680, 420], [668, 156, 788, 236]], [[382, 240, 602, 440], [490, 220, 690, 420], [674, 154, 794, 234]], [[400, 240, 620, 440], [500, 220, 700, 420], [680, 153, 800, 233]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_single_sweep_big_motion.json b/examples/wanvideo/model_inference/example_single_sweep_big_motion.json new file mode 100644 index 0000000000000000000000000000000000000000..9d5bfd84b3f484e40c74f6cdcdf93696dbfb6182 --- /dev/null +++ b/examples/wanvideo/model_inference/example_single_sweep_big_motion.json @@ -0,0 +1 @@ +[[{"x0": 60, "y0": 110, "x1": 280, "y1": 370}], [{"x0": 84, "y0": 121, "x1": 304, "y1": 381}], [{"x0": 108, "y0": 131, "x1": 328, "y1": 391}], [{"x0": 132, "y0": 138, "x1": 352, "y1": 398}], [{"x0": 156, "y0": 143, "x1": 376, "y1": 403}], [{"x0": 180, "y0": 145, "x1": 400, "y1": 405}], [{"x0": 204, "y0": 143, "x1": 424, "y1": 403}], [{"x0": 228, "y0": 138, "x1": 448, "y1": 398}], [{"x0": 252, "y0": 131, "x1": 472, "y1": 391}], [{"x0": 276, "y0": 121, "x1": 496, "y1": 381}], [{"x0": 300, "y0": 110, "x1": 520, "y1": 370}], [{"x0": 324, "y0": 99, "x1": 544, "y1": 359}], [{"x0": 348, "y0": 89, "x1": 568, "y1": 349}], [{"x0": 372, "y0": 82, "x1": 592, "y1": 342}], [{"x0": 396, "y0": 77, "x1": 616, "y1": 337}], [{"x0": 420, "y0": 75, "x1": 640, "y1": 335}], [{"x0": 444, "y0": 77, "x1": 664, "y1": 337}], [{"x0": 468, "y0": 82, "x1": 688, "y1": 342}], [{"x0": 492, "y0": 89, "x1": 712, "y1": 349}], [{"x0": 516, "y0": 99, "x1": 736, "y1": 359}], [{"x0": 540, "y0": 110, "x1": 760, "y1": 370}]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_three_diagonal_big_motion.json b/examples/wanvideo/model_inference/example_three_diagonal_big_motion.json new file mode 100644 index 0000000000000000000000000000000000000000..4f5b420186c9c039f06c4ca3ba61a99bb56a6958 --- /dev/null +++ b/examples/wanvideo/model_inference/example_three_diagonal_big_motion.json @@ -0,0 +1 @@ +[[[60, 60, 240, 280], [330, 220, 530, 420], [620, 80, 740, 200]], [[76, 68, 256, 288], [318, 214, 518, 414], [602, 90, 722, 210]], [[92, 76, 272, 296], [306, 208, 506, 408], [584, 100, 704, 220]], [[108, 84, 288, 304], [294, 202, 494, 402], [566, 110, 686, 230]], [[124, 92, 304, 312], [282, 196, 482, 396], [548, 120, 668, 240]], [[140, 100, 320, 320], [270, 190, 470, 390], [530, 130, 650, 250]], [[156, 108, 336, 328], [258, 184, 458, 384], [512, 140, 632, 260]], [[172, 116, 352, 336], [246, 178, 446, 378], [494, 150, 614, 270]], [[188, 124, 368, 344], [234, 172, 434, 372], [476, 160, 596, 280]], [[204, 132, 384, 352], [222, 166, 422, 366], [458, 170, 578, 290]], [[220, 140, 400, 360], [210, 160, 410, 360], [440, 180, 560, 300]], [[236, 148, 416, 368], [198, 154, 398, 354], [422, 190, 542, 310]], [[252, 156, 432, 376], [186, 148, 386, 348], [404, 200, 524, 320]], [[268, 164, 448, 384], [174, 142, 374, 342], [386, 210, 506, 330]], [[284, 172, 464, 392], [162, 136, 362, 336], [368, 220, 488, 340]], [[300, 180, 480, 400], [150, 130, 350, 330], [350, 230, 470, 350]], [[316, 188, 496, 408], [138, 124, 338, 324], [332, 240, 452, 360]], [[332, 196, 512, 416], [126, 118, 326, 318], [314, 250, 434, 370]], [[348, 204, 528, 424], [114, 112, 314, 312], [296, 260, 416, 380]], [[364, 212, 544, 432], [102, 106, 302, 306], [278, 270, 398, 390]], [[380, 220, 560, 440], [90, 100, 290, 300], [260, 280, 380, 400]]] \ No newline at end of file 
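+# Sketch (not one of the original example files): each bbox JSON in this folder stores
+# one list per frame, holding one box per instance as [x0, y0, x1, y1] pixel values
+# (some files use an equivalent {"x0": ..., "y0": ..., "x1": ..., "y1": ...} dict).
+# A trajectory like those above can be produced by linear interpolation between a start
+# and an end keyframe; the file name and box values here are purely illustrative.
+import json
+
+def interpolate_boxes(start_boxes, end_boxes, num_keyframes=21):
+    frames = []
+    for i in range(num_keyframes):
+        t = i / (num_keyframes - 1)
+        frames.append([
+            [round(s + (e - s) * t) for s, e in zip(sb, eb)]
+            for sb, eb in zip(start_boxes, end_boxes)
+        ])
+    return frames
+
+with open("example_generated_bbox.json", "w") as f:
+    json.dump(interpolate_boxes([[40, 80, 240, 380]], [[480, 80, 680, 380]]), f)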
diff --git a/examples/wanvideo/model_inference/example_two_crossing_big_motion.json b/examples/wanvideo/model_inference/example_two_crossing_big_motion.json new file mode 100644 index 0000000000000000000000000000000000000000..dabf781c199d0b2e568301864cdc90d700b0636a --- /dev/null +++ b/examples/wanvideo/model_inference/example_two_crossing_big_motion.json @@ -0,0 +1 @@ +[[[40, 80, 240, 380], [560, 100, 760, 420]], [[62, 80, 262, 380], [538, 100, 738, 420]], [[84, 80, 284, 380], [516, 100, 716, 420]], [[106, 80, 306, 380], [494, 100, 694, 420]], [[128, 80, 328, 380], [472, 100, 672, 420]], [[150, 80, 350, 380], [450, 100, 650, 420]], [[172, 80, 372, 380], [428, 100, 628, 420]], [[194, 80, 394, 380], [406, 100, 606, 420]], [[216, 80, 416, 380], [384, 100, 584, 420]], [[238, 80, 438, 380], [362, 100, 562, 420]], [[260, 80, 460, 380], [340, 100, 540, 420]], [[282, 80, 482, 380], [318, 100, 518, 420]], [[304, 80, 504, 380], [296, 100, 496, 420]], [[326, 80, 526, 380], [274, 100, 474, 420]], [[348, 80, 548, 380], [252, 100, 452, 420]], [[370, 80, 570, 380], [230, 100, 430, 420]], [[392, 80, 592, 380], [208, 100, 408, 420]], [[414, 80, 614, 380], [186, 100, 386, 420]], [[436, 80, 636, 380], [164, 100, 364, 420]], [[458, 80, 658, 380], [142, 100, 342, 420]], [[480, 80, 680, 380], [120, 100, 320, 420]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_two_scooters_crossing.json b/examples/wanvideo/model_inference/example_two_scooters_crossing.json new file mode 100644 index 0000000000000000000000000000000000000000..31f468c14c54b4ee95784cda8b610cecd09510bb --- /dev/null +++ b/examples/wanvideo/model_inference/example_two_scooters_crossing.json @@ -0,0 +1 @@ +[[[40, 260, 220, 420], [580, 80, 760, 240]], [[64, 258, 244, 418], [555, 79, 735, 239]], [[88, 257, 268, 417], [530, 78, 710, 238]], [[112, 256, 292, 416], [505, 77, 685, 237]], [[136, 254, 316, 414], [480, 76, 660, 236]], [[160, 252, 340, 412], [455, 75, 635, 235]], [[184, 251, 364, 411], [430, 74, 610, 234]], [[208, 250, 388, 410], [405, 73, 585, 233]], [[232, 248, 412, 408], [380, 72, 560, 232]], [[256, 246, 436, 406], [355, 71, 535, 231]], [[280, 245, 460, 405], [330, 70, 510, 230]], [[304, 244, 484, 404], [305, 69, 485, 229]], [[328, 242, 508, 402], [280, 68, 460, 228]], [[352, 240, 532, 400], [255, 67, 435, 227]], [[376, 239, 556, 399], [230, 66, 410, 226]], [[400, 238, 580, 398], [205, 65, 385, 225]], [[424, 236, 604, 396], [180, 64, 360, 224]], [[448, 234, 628, 394], [155, 63, 335, 223]], [[472, 233, 652, 393], [130, 62, 310, 222]], [[496, 232, 676, 392], [105, 61, 285, 221]], [[520, 230, 700, 390], [80, 60, 260, 220]]] \ No newline at end of file diff --git a/examples/wanvideo/model_inference/example_two_students_drone.json b/examples/wanvideo/model_inference/example_two_students_drone.json new file mode 100644 index 0000000000000000000000000000000000000000..cce56c4510e287a842f22da73e68fc3938eabb8b --- /dev/null +++ b/examples/wanvideo/model_inference/example_two_students_drone.json @@ -0,0 +1 @@ +[[[60, 140, 260, 440], [520, 140, 720, 440], [380, 60, 440, 120]], [[70, 140, 270, 440], [511, 140, 711, 440], [381, 61, 441, 121]], [[80, 140, 280, 440], [502, 140, 702, 440], [382, 62, 442, 122]], [[90, 140, 290, 440], [493, 140, 693, 440], [383, 63, 443, 123]], [[100, 140, 300, 440], [484, 140, 684, 440], [384, 64, 444, 124]], [[110, 140, 310, 440], [475, 140, 675, 440], [385, 65, 445, 125]], [[120, 140, 320, 440], [466, 140, 666, 440], [386, 66, 446, 126]], [[130, 140, 330, 440], [457, 140, 657, 440], 
[387, 67, 447, 127]], [[140, 140, 340, 440], [448, 140, 648, 440], [388, 68, 448, 128]], [[150, 140, 350, 440], [439, 140, 639, 440], [389, 69, 449, 129]], [[160, 140, 360, 440], [430, 140, 630, 440], [390, 70, 450, 130]], [[170, 140, 370, 440], [421, 140, 621, 440], [391, 71, 451, 131]], [[180, 140, 380, 440], [412, 140, 612, 440], [392, 72, 452, 132]], [[190, 140, 390, 440], [403, 140, 603, 440], [393, 73, 453, 133]], [[200, 140, 400, 440], [394, 140, 594, 440], [394, 74, 454, 134]], [[210, 140, 410, 440], [385, 140, 585, 440], [395, 75, 455, 135]], [[220, 140, 420, 440], [376, 140, 576, 440], [396, 76, 456, 136]], [[230, 140, 430, 440], [367, 140, 567, 440], [397, 77, 457, 137]], [[240, 140, 440, 440], [358, 140, 558, 440], [398, 78, 458, 138]], [[250, 140, 450, 440], [349, 140, 549, 440], [399, 79, 459, 139]], [[260, 140, 460, 440], [340, 140, 540, 440], [400, 80, 460, 140]]] \ No newline at end of file
diff --git a/examples/wanvideo/model_inference/instanceV.py b/examples/wanvideo/model_inference/instanceV.py new file mode 100644 index 0000000000000000000000000000000000000000..979e1bb8bb1d89e5e4da920d7aed214ad7e8d65b --- /dev/null +++ b/examples/wanvideo/model_inference/instanceV.py @@ -0,0 +1,42 @@ +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +H, W, num_frames = 480, 832, 81 +# The latent temporal length is (num_frames-1)//4+1 = 21 (the same formula NoiseInitializer uses) +f_lat = (num_frames - 1) // 4 + 1 + +# Example: a single instance (the dog) with a rough bbox held fixed for the whole clip +dog_box = (0.20 * W, 0.35 * H, 0.80 * W, 0.92 * H) # (x0,y0,x1,y1) +instance_bboxes = [[dog_box] for _ in range(f_lat)] + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,最差质量,低质量,JPEG压缩残留,畸形,毁容,多余的手指,背景人很多", + seed=0, + height=H, width=W, num_frames=num_frames, + tiled=True, + + # ===== InstanceV new args ===== + instance_prompts=[ + "一只棕黄色、毛发柔软的小狗,奔跑时耳朵竖起,表情欢快", + ], + instance_bboxes=instance_bboxes, + + # ===== SAUG (optional) ===== + saug_scale=0.6, # try values in the 0.3~1.0 range + saug_drop_prob=0.0, +) + +save_video(video, "video_1_instancev.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_inference/instanceV_inference.py b/examples/wanvideo/model_inference/instanceV_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3cdfc902d56f9cfd6937c1ec57f75ec24002df67 --- /dev/null +++ b/examples/wanvideo/model_inference/instanceV_inference.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +""" +InstanceV inference script + +Usage: + python instanceV_inference.py \ + --checkpoint path/to/step-15400.safetensors \ + --prompt "A scene with two objects" \ + --instance_prompts "object1 description" "object2 description" \ + --bboxes_json bboxes.json \ + --output output_video.mp4 +""" + +import torch +import json +import argparse +import os +import sys +import types +from pathlib import Path +from PIL import ImageDraw + +# Add the project root to sys.path
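+# (parents[3] walks up from examples/wanvideo/model_inference/ to the repository root, so the in-repo diffsynth package is importable without installation)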
+sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig +from diffsynth.models.wan_video_dit_instancev import ( + WanModel, DiTBlock, + SharedTimestepAdaptivePromptEnhancement, + InstanceAwareMaskedCrossAttention, +) +from safetensors.torch import load_file + + +def parse_args(): + parser = argparse.ArgumentParser(description="InstanceV Inference") + + # 模型配置 + parser.add_argument("--checkpoint", type=str, default=None, + help="InstanceV checkpoint 路径 (e.g., step-15400.safetensors)") + parser.add_argument("--model_id", type=str, default="Wan-AI/Wan2.1-T2V-1.3B", + help="Base model ID") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--dtype", type=str, default="bfloat16", + choices=["bfloat16", "float16", "float32"]) + parser.add_argument("--low_vram", action="store_true", + help="启用低显存模式 (模型 offload 到 disk/CPU)") + parser.add_argument("--dit_path", type=str, default=None, + help="Local path to diffusion_pytorch_model*.safetensors (no download).") + parser.add_argument("--text_encoder_path", type=str, default=None, + help="Local path to models_t5_umt5-xxl-enc-bf16.safetensors (no download).") + parser.add_argument("--vae_path", type=str, default=None, + help="Local path to Wan2.1_VAE.safetensors (no download).") + parser.add_argument("--tokenizer_path", type=str, default=None, + help="Local tokenizer directory (e.g., google/umt5-xxl).") + + # 视频配置 + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--seed", type=int, default=42) + + # 提示词 + parser.add_argument("--prompt", type=str, required=True, + help="全局场景描述") + parser.add_argument("--negative_prompt", type=str, + default="色调艳丽,过曝,静态,细节模糊不清,字幕,最差质量,低质量,JPEG压缩残留,畸形,毁容", + help="负面提示词") + parser.add_argument("--instance_prompts", type=str, nargs="+", required=True, + help="每个实例的描述 (空格分隔)") + + # Bbox 配置 + parser.add_argument("--bboxes_json", type=str, default=None, + help="JSON 文件,格式: [[{x0,y0,x1,y1}, ...], ...] 
每帧每实例的 bbox") + parser.add_argument("--static_bboxes", type=str, nargs="+", default=None, + help="静态 bbox (所有帧相同),格式: 'x0,y0,x1,y1' 每个实例一个") + + # SAUG 配置 + parser.add_argument("--saug_scale", type=float, default=0.6, + help="SAUG scale (推理时推荐 0.3~1.0)") + parser.add_argument("--saug_drop_prob", type=float, default=0.0, + help="SAUG dropout (推理时通常为 0)") + + # 输出 + parser.add_argument("--output", type=str, default="output_instancev.mp4") + parser.add_argument("--fps", type=int, default=15) + + return parser.parse_args() + + +def get_underlying_module(module): + """获取底层模块(解包 VRAM wrapper)""" + if hasattr(module, 'module'): + return module.module + return module + + +def load_instancev_checkpoint(pipe, checkpoint_path): + """加载训练好的 InstanceV checkpoint""" + print(f"Loading InstanceV checkpoint from: {checkpoint_path}") + + state_dict = load_file(checkpoint_path) + + dit = get_underlying_module(pipe.dit) + dit.enable_instancev = True + block0 = get_underlying_module(dit.blocks[0]) + dim = block0.self_attn.q.out_features // block0.self_attn.num_heads * block0.self_attn.num_heads + num_heads = block0.self_attn.num_heads + num_layers = len(dit.blocks) + + # 1) 添加 STAPE 模块(如果还没有) + if not hasattr(dit, 'stape') or dit.stape is None: + dit.stape = SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads) + dit.stape = dit.stape.to(device=pipe.device, dtype=pipe.torch_dtype) + print(f" -> Added STAPE module (dim={dim}, num_heads={num_heads})") + + # 2) 为每个 block 添加 IMCA / mv / norm_imca(如果还没有) + for block_idx in range(num_layers): + block_wrapper = dit.blocks[block_idx] + block = get_underlying_module(block_wrapper) + block.enable_instancev = True + block.stape = dit.stape + + if not hasattr(block, 'imca') or block.imca is None: + block.imca = InstanceAwareMaskedCrossAttention(dim=dim, num_heads=num_heads) + block.imca = block.imca.to(device=pipe.device, dtype=pipe.torch_dtype) + # 尝试从 cross-attn 初始化(与训练一致) + try: + block.imca.attn.q.load_state_dict(block.cross_attn.q.state_dict()) + block.imca.attn.k.load_state_dict(block.cross_attn.k.state_dict()) + block.imca.attn.v.load_state_dict(block.cross_attn.v.state_dict()) + block.imca.attn.o.load_state_dict(block.cross_attn.o.state_dict()) + block.imca.attn.norm_q.load_state_dict(block.cross_attn.norm_q.state_dict()) + block.imca.attn.norm_k.load_state_dict(block.cross_attn.norm_k.state_dict()) + except Exception: + pass + + if not hasattr(block, 'mv') or block.mv is None: + block.mv = torch.nn.Parameter(torch.zeros(1, device=pipe.device, dtype=pipe.torch_dtype)) + + if not hasattr(block, 'norm_imca') or block.norm_imca is None: + block.norm_imca = torch.nn.LayerNorm(dim, elementwise_affine=False) + block.norm_imca = block.norm_imca.to(device=pipe.device, dtype=pipe.torch_dtype) + + # 替换 forward 方法 + block.forward = types.MethodType(DiTBlock.forward, block) + + print(f" -> Added IMCA, mv, norm_imca to {num_layers} blocks") + + # 3) 加载权重 + # checkpoint 保存时用了 remove_prefix="pipe.dit.",所以 key 不带前缀 + loaded_keys = [] + missing_keys = [] + + for key, value in state_dict.items(): + # 尝试直接加载 + try: + parts = key.split('.') + obj = dit + for part in parts[:-1]: + if part.isdigit(): + obj = get_underlying_module(obj[int(part)]) + else: + obj = get_underlying_module(getattr(obj, part)) + + param_name = parts[-1] + if hasattr(obj, param_name): + param = getattr(obj, param_name) + if isinstance(param, torch.nn.Parameter): + param.data.copy_(value.to(device=pipe.device, dtype=pipe.torch_dtype)) + elif isinstance(param, torch.Tensor): + 
param.copy_(value.to(device=pipe.device, dtype=pipe.torch_dtype)) + else: + setattr(obj, param_name, value.to(device=pipe.device, dtype=pipe.torch_dtype)) + loaded_keys.append(key) + else: + missing_keys.append(key) + except Exception as e: + missing_keys.append(f"{key}: {e}") + + print(f" -> Loaded {len(loaded_keys)} keys") + if missing_keys: + print(f" -> Missing/skipped {len(missing_keys)} keys") + + # 设置为推理模式 + dit.eval() + + return pipe + + +def parse_bboxes(args): + """解析 bbox 配置""" + f_lat = (args.num_frames - 1) // 4 + 1 + + if args.bboxes_json: + # 从 JSON 文件加载 + with open(args.bboxes_json, 'r') as f: + bboxes_data = json.load(f) + + # 格式: [[{x0, y0, x1, y1}, ...], ...] 每帧每实例 + # 转换为: [[bbox_inst0, bbox_inst1, ...], ...] 每帧 + bboxes = [] + for frame_data in bboxes_data: + frame_bboxes = [] + for box in frame_data: + if box is None: + frame_bboxes.append(None) + elif isinstance(box, dict): + frame_bboxes.append((box["x0"], box["y0"], box["x1"], box["y1"])) + elif isinstance(box, list) and len(box) == 4: + frame_bboxes.append(tuple(box)) + else: + frame_bboxes.append(None) + bboxes.append(frame_bboxes) + + # 如果帧数不匹配,需要采样 + if len(bboxes) != f_lat: + import numpy as np + indices = np.linspace(0, len(bboxes) - 1, f_lat, dtype=int) + bboxes = [bboxes[i] for i in indices] + + return bboxes + + elif args.static_bboxes: + # 静态 bbox,所有帧相同 + bboxes_per_instance = [] + for bbox_str in args.static_bboxes: + coords = [float(x.strip()) for x in bbox_str.split(',')] + if len(coords) == 4: + bboxes_per_instance.append(tuple(coords)) + else: + raise ValueError(f"Invalid bbox format: {bbox_str}, expected 'x0,y0,x1,y1'") + + # 扩展到所有帧 + bboxes = [bboxes_per_instance for _ in range(f_lat)] + return bboxes + + else: + # 默认:每个实例占据不同区域 + nins = len(args.instance_prompts) + W, H = args.width, args.height + + # 均匀分布 + default_bboxes = [] + for i in range(nins): + x0 = (i / nins) * W * 0.8 + W * 0.1 + x1 = ((i + 1) / nins) * W * 0.8 + W * 0.1 + y0 = H * 0.2 + y1 = H * 0.8 + default_bboxes.append((x0, y0, x1, y1)) + + bboxes = [default_bboxes for _ in range(f_lat)] + print(f"[Warning] No bboxes provided, using default layout: {default_bboxes}") + return bboxes + + +def upsample_bboxes_to_frames(bboxes, num_frames): + """Upsample bbox list from f_lat to num_frames.""" + if not bboxes: + return [] + if len(bboxes) == num_frames: + return bboxes + if num_frames <= 1: + return [bboxes[0]] + max_idx = len(bboxes) - 1 + mapped = [] + for i in range(num_frames): + idx = int(round(i * max_idx / (num_frames - 1))) + mapped.append(bboxes[idx]) + return mapped + + +def draw_bboxes_on_video(video, bboxes, num_frames): + """Draw instance bboxes on output frames.""" + if not video or not bboxes: + return video + bboxes_per_frame = upsample_bboxes_to_frames(bboxes, num_frames) + colors = ["#FF4D4D", "#3FA9F5", "#7ED957", "#FFB347", "#B46BFF", "#00C2A8", "#FFD93D"] + for frame_idx, frame in enumerate(video): + draw = ImageDraw.Draw(frame) + frame_bboxes = bboxes_per_frame[min(frame_idx, len(bboxes_per_frame) - 1)] + for inst_id, box in enumerate(frame_bboxes): + if box is None: + continue + x0, y0, x1, y1 = box + color = colors[inst_id % len(colors)] + draw.rectangle([x0, y0, x1, y1], outline=color, width=3) + return video + + +def main(): + args = parse_args() + + # 设置 dtype + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + torch_dtype = dtype_map[args.dtype] + + print("=" * 60) + print("InstanceV Inference") + print("=" * 60) + print(f"Prompt: {args.prompt}") + 
print(f"Instance Prompts: {args.instance_prompts}") + print(f"Resolution: {args.width}x{args.height}, {args.num_frames} frames") + print(f"Checkpoint: {args.checkpoint}") + print("=" * 60) + + # 1) 加载基础 Pipeline + print("\n[1/4] Loading base model...") + + use_local = any([args.dit_path, args.text_encoder_path, args.vae_path, args.tokenizer_path]) + if use_local: + missing = [] + if not args.dit_path: + missing.append("--dit_path") + if not args.text_encoder_path: + missing.append("--text_encoder_path") + if not args.vae_path: + missing.append("--vae_path") + if not args.tokenizer_path: + missing.append("--tokenizer_path") + if missing: + raise ValueError(f"Local path mode requires: {', '.join(missing)}") + + if args.low_vram: + # 低显存模式:模型 offload 到 disk/CPU + print(" -> Low VRAM mode enabled (disk offload)") + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch_dtype, + "onload_device": "cpu", + "preparing_dtype": torch_dtype, + "preparing_device": "cuda", + "computation_dtype": torch_dtype, + "computation_device": "cuda", + } + # 计算可用显存 + try: + vram_limit = torch.cuda.mem_get_info(args.device)[0] / (1024 ** 3) - 2 # 剩余显存 - 2GB buffer + except: + vram_limit = 8 # 默认 8GB + print(f" -> VRAM limit: {vram_limit:.1f} GB") + + if use_local: + model_configs = [ + ModelConfig(path=args.dit_path, **vram_config), + ModelConfig(path=args.text_encoder_path, **vram_config), + ModelConfig(path=args.vae_path, **vram_config), + ] + tokenizer_config = ModelConfig(path=args.tokenizer_path) + else: + model_configs = [ + ModelConfig(model_id=args.model_id, origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ] + tokenizer_config = ModelConfig(model_id=args.model_id, origin_file_pattern="google/umt5-xxl/") + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch_dtype, + device=args.device, + model_configs=model_configs, + tokenizer_config=tokenizer_config, + vram_limit=vram_limit, + ) + else: + if use_local: + model_configs = [ + ModelConfig(path=args.dit_path), + ModelConfig(path=args.text_encoder_path), + ModelConfig(path=args.vae_path), + ] + tokenizer_config = ModelConfig(path=args.tokenizer_path) + else: + model_configs = [ + ModelConfig(model_id=args.model_id, origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors"), + ] + tokenizer_config = ModelConfig(model_id=args.model_id, origin_file_pattern="google/umt5-xxl/") + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch_dtype, + device=args.device, + model_configs=model_configs, + tokenizer_config=tokenizer_config, + ) + + # 2) 加载 InstanceV checkpoint + if args.checkpoint: + print("\n[2/4] Loading InstanceV checkpoint...") + pipe = load_instancev_checkpoint(pipe, args.checkpoint) + else: + print("\n[2/4] No checkpoint provided, using base model (InstanceV modules will be random)") + + # 3) 解析 bboxes + print("\n[3/4] Parsing bboxes...") + instance_bboxes = parse_bboxes(args) + print(f" -> {len(instance_bboxes)} frames, 
{len(instance_bboxes[0])} instances per frame") + + # 4) 生成视频 + print("\n[4/4] Generating video...") + video = pipe( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + cfg_scale=args.cfg_scale, + tiled=True, + + # ===== InstanceV ===== + instance_prompts=args.instance_prompts, + instance_bboxes=instance_bboxes, + saug_scale=args.saug_scale, + saug_drop_prob=args.saug_drop_prob, + ) + + # draw bboxes on the output video for visualization + video = draw_bboxes_on_video(video, instance_bboxes, args.num_frames) + + # 保存 + save_video(video, args.output, fps=args.fps, quality=5) + print(f"\n✅ Video saved to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/model_inference/instanceV_simple.py b/examples/wanvideo/model_inference/instanceV_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..b531faa6064c764c36ef6f55a8289f7f636088e5 --- /dev/null +++ b/examples/wanvideo/model_inference/instanceV_simple.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +InstanceV 简化推理脚本 + +用法: + CUDA_VISIBLE_DEVICES=0 python instanceV_simple.py +""" + +import torch +import types +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video_instanceV import WanVideoPipeline, ModelConfig +from diffsynth.models.wan_video_dit_instancev import ( + DiTBlock, + SharedTimestepAdaptivePromptEnhancement, + InstanceAwareMaskedCrossAttention, +) +from safetensors.torch import load_file + +# ========== 配置 ========== +CHECKPOINT = "models/train/instancev/step-15400.safetensors" +OUTPUT = "outputs/instancev_demo.mp4" + +PROMPT = "纪实摄影风格画面,绿茵茵的草地,阳光明媚,一只活泼的小狗在快速奔跑。中景侧面移动视角。" +NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,最差质量,低质量" +INSTANCE_PROMPTS = ["一只棕黄色毛发柔软的小狗,奔跑时耳朵竖起,表情欢快"] + +HEIGHT, WIDTH, NUM_FRAMES = 480, 832, 81 +SEED = 42 +SAUG_SCALE = 0.6 + +# 小狗从左到右移动的 bbox 序列(21 帧 = f_lat) +def generate_moving_bbox(f_lat, w, h): + """生成从左到右移动的 bbox""" + bboxes = [] + for t in range(f_lat): + progress = t / (f_lat - 1) # 0 -> 1 + x0 = int(100 + progress * 400) + y0 = int(150 - progress * 50) + x1 = x0 + 250 + y1 = y0 + 250 + bboxes.append([(x0, y0, x1, y1)]) # 每帧 1 个实例 + return bboxes + +# ========== 主程序 ========== +def main(): + os.makedirs("outputs", exist_ok=True) + + print("=" * 60) + print("InstanceV Inference Demo") + print("=" * 60) + + # 1) 加载 Pipeline + print("\n[1/4] Loading base model...") + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + ) + + # 2) 添加 InstanceV 模块并加载权重 + print("\n[2/4] Loading InstanceV checkpoint...") + dit = pipe.dit + dim = dit.blocks[0].self_attn.q.out_features + num_heads = dit.blocks[0].self_attn.num_heads + num_layers = len(dit.blocks) + + # 添加 STAPE + if not hasattr(dit, 'stape') or dit.stape is None: + dit.stape = 
SharedTimestepAdaptivePromptEnhancement(dim=dim, num_heads=num_heads) + dit.stape = dit.stape.to(device=pipe.device, dtype=pipe.torch_dtype) + print(f" -> Added STAPE (dim={dim}, heads={num_heads})") + + # 添加 IMCA 到每个 block + for block_idx, block in enumerate(dit.blocks): + if not hasattr(block, 'imca') or block.imca is None: + block.imca = InstanceAwareMaskedCrossAttention(dim=dim, num_heads=num_heads) + block.imca = block.imca.to(device=pipe.device, dtype=pipe.torch_dtype) + + block.mv = torch.nn.Linear(dim, dim, bias=False) + block.mv = block.mv.to(device=pipe.device, dtype=pipe.torch_dtype) + + block.norm_imca = torch.nn.LayerNorm(dim, elementwise_affine=False) + block.norm_imca = block.norm_imca.to(device=pipe.device, dtype=pipe.torch_dtype) + + block.forward = types.MethodType(DiTBlock.forward, block) + + print(f" -> Added IMCA to {num_layers} blocks") + + # 加载 checkpoint + if os.path.exists(CHECKPOINT): + state_dict = load_file(CHECKPOINT) + loaded = 0 + for key, value in state_dict.items(): + try: + parts = key.split('.') + obj = dit + for part in parts[:-1]: + if part.isdigit(): + obj = obj[int(part)] + else: + obj = getattr(obj, part) + + param_name = parts[-1] + param = getattr(obj, param_name) + if isinstance(param, torch.nn.Parameter): + param.data.copy_(value.to(device=pipe.device, dtype=pipe.torch_dtype)) + loaded += 1 + elif isinstance(param, torch.Tensor): + param.copy_(value.to(device=pipe.device, dtype=pipe.torch_dtype)) + loaded += 1 + except: + pass + print(f" -> Loaded {loaded} parameters from checkpoint") + else: + print(f" -> [Warning] Checkpoint not found: {CHECKPOINT}") + + dit.eval() + + # 3) 准备 bbox + print("\n[3/4] Preparing bboxes...") + f_lat = (NUM_FRAMES - 1) // 4 + 1 + instance_bboxes = generate_moving_bbox(f_lat, WIDTH, HEIGHT) + print(f" -> {f_lat} frames, {len(instance_bboxes[0])} instances") + + # 4) 生成视频 + print("\n[4/4] Generating video...") + video = pipe( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + seed=SEED, + height=HEIGHT, width=WIDTH, num_frames=NUM_FRAMES, + tiled=True, + + # InstanceV + instance_prompts=INSTANCE_PROMPTS, + instance_bboxes=instance_bboxes, + saug_scale=SAUG_SCALE, + saug_drop_prob=0.0, + ) + + save_video(video, OUTPUT, fps=15, quality=5) + print(f"\n✅ Video saved to: {OUTPUT}") + + +if __name__ == "__main__": + main() + diff --git a/examples/wanvideo/model_inference/krea-realtime-video.py b/examples/wanvideo/model_inference/krea-realtime-video.py new file mode 100644 index 0000000000000000000000000000000000000000..516b04f5b531ae322d75485feeaee262b63172fd --- /dev/null +++ b/examples/wanvideo/model_inference/krea-realtime-video.py @@ -0,0 +1,25 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) + +# Text-to-video +video = pipe( + prompt="a cat sitting on a boat", + num_inference_steps=6, num_frames=81, + seed=0, tiled=True, + cfg_scale=1, + sigma_shift=20, +) +save_video(video, "video_krea-realtime-video.mp4", fps=15, 
quality=5) diff --git a/examples/wanvideo/model_inference/mc_prompts_10.txt b/examples/wanvideo/model_inference/mc_prompts_10.txt new file mode 100644 index 0000000000000000000000000000000000000000..134b9ba8a234bf34f6a7b25a25725f71151e7e17 --- /dev/null +++ b/examples/wanvideo/model_inference/mc_prompts_10.txt @@ -0,0 +1,10 @@ +The video showcases a first-person perspective within the game Minecraft inside a small village storage room. The camera is completely still and the player stands in place, facing a wooden chest on the floor. The chest lid opens slowly and then closes, repeating once, while a villager can be seen through an open doorway in the background. The scene keeps the blocky Minecraft look and only the chest and villager animate. +The video showcases a first-person perspective within the game Minecraft at a village farm. The player stands perfectly still at the edge of tilled soil, looking at wheat crops in front of them. Over the clip the wheat visibly grows through stages from short green sprouts to full golden wheat, while a villager farmer walks nearby. The camera does not move. +The video showcases a first-person perspective within the game Minecraft in front of a house door. The player stands still and looks directly at an oak door and its handle. The door swings open and then closes twice, as if someone is using it, while a villager passes behind the door frame. The viewpoint remains fixed with no walking. +The video showcases a first-person perspective within the game Minecraft next to a double chest placed by a village path. The player remains stationary and the camera is locked. The double chest opens and closes and a villager briefly stops near it as if checking storage, then continues walking away. Only the chest lid and villager movement change. +The video showcases a first-person perspective within the game Minecraft inside a simple farmhouse. The player stands still looking at a wooden door and a chest beside the wall. A villager enters, the door opens inward, and the villager steps into frame while the chest lid flips open once and closes. The camera stays perfectly steady. +The video showcases a first-person perspective within the game Minecraft at a carrot field in a village. The player does not move, staring at neat rows of carrots planted in tilled soil. The carrots grow visibly taller and fuller, and a villager farmer strolls along the row. The background shows a nearby house door but the viewpoint remains static. +The video showcases a first-person perspective within the game Minecraft in a village market area. The player stands still facing a villager next to a chest and a door behind them. The villager nods and shifts slightly as if trading, while the chest opens and closes once and the door swings briefly. The camera remains fixed with no player motion. +The video showcases a first-person perspective within the game Minecraft at the entrance of a village house. The player stands in place, looking into the doorway where a chest is visible inside. A villager approaches, opens the door, pauses, then closes it, all while the camera stays completely still. The chest remains in view as a background object. +The video showcases a first-person perspective within the game Minecraft in a small greenhouse-like farm corner with a chest in the corner and a door on one side. The player stands still, watching young crops in front of them. The crops grow rapidly over the clip, and a villager briefly walks past the doorway. The camera does not pan or tilt. 
+The video showcases a first-person perspective within the game Minecraft on a village farm path. The player stands motionless, looking at a chest near the farm and a wooden door in the distance. The chest opens and closes once, the door opens and shuts, and the crops beside the path visibly advance in growth, while a villager walks through the scene. The viewpoint stays locked and steady. diff --git a/examples/wanvideo/model_inference/render_iground_bboxes.py b/examples/wanvideo/model_inference/render_iground_bboxes.py new file mode 100644 index 0000000000000000000000000000000000000000..fa6860b6ed4d5520e4e079728e5a53c427571eae --- /dev/null +++ b/examples/wanvideo/model_inference/render_iground_bboxes.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +""" +Render iGround video with bbox overlays for quick dataset verification. +""" + +import argparse +import json +import os +from typing import Dict, Tuple + +import imageio.v2 as imageio +import numpy as np +from PIL import Image, ImageDraw + + +def parse_args(): + parser = argparse.ArgumentParser(description="Render iGround bboxes on video") + parser.add_argument( + "--iground_jsonl", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/iGround_train_set_processed.jsonl", + help="Path to iGround processed JSONL.", + ) + parser.add_argument( + "--clips_dir", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/Clips/train", + help="Directory with iGround clips.", + ) + parser.add_argument("--video_id", type=str, required=True, help="Video ID to render.") + parser.add_argument("--clip_id", type=str, required=True, help="Clip ID to render.") + parser.add_argument( + "--output", + type=str, + default=None, + help="Output video path (mp4).", + ) + parser.add_argument("--max_frames", type=int, default=None, help="Limit frames for preview.") + parser.add_argument("--line_width", type=int, default=3, help="BBox line width.") + parser.add_argument("--draw_labels", action="store_true", help="Draw label text.") + return parser.parse_args() + + +def find_sample(jsonl_path: str, video_id: str, clip_id: str) -> dict: + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + sample = json.loads(line) + if sample.get("video_id") == video_id and str(sample.get("clip_id")) == str(clip_id): + return sample + raise ValueError(f"Sample not found: video_id={video_id}, clip_id={clip_id}") + + +def color_for_label(label: str) -> Tuple[int, int, int]: + seed = abs(hash(label)) % 360 + r = int(127 + 127 * ((seed * 3) % 100) / 100) + g = int(127 + 127 * ((seed * 7) % 100) / 100) + b = int(127 + 127 * ((seed * 11) % 100) / 100) + return r, g, b + + +def build_color_map(phrases) -> Dict[str, Tuple[int, int, int]]: + return {p: color_for_label(p) for p in phrases} + + +def draw_bboxes(img: Image.Image, labels, bboxes, color_map, line_width: int, draw_labels: bool): + draw = ImageDraw.Draw(img) + for label, bbox in zip(labels, bboxes): + if bbox is None or len(bbox) != 4: + continue + x0, y0, x1, y1 = bbox + color = color_map.get(label, (255, 0, 0)) + for w in range(line_width): + draw.rectangle([x0 - w, y0 - w, x1 + w, y1 + w], outline=color) + if draw_labels: + draw.text((x0 + 2, y0 + 2), str(label), fill=color) + + +def main(): + args = parse_args() + sample = find_sample(args.iground_jsonl, args.video_id, args.clip_id) + + clip_name = f"{sample['video_id']}_{sample['clip_id']}.mp4" + clip_path = os.path.join(args.clips_dir, clip_name) + if not os.path.isfile(clip_path): + raise FileNotFoundError(f"Clip 
not found: {clip_path}") + + output_path = args.output + if output_path is None: + output_path = f"{sample['video_id']}_{sample['clip_id']}_bboxes.mp4" + + phrases = sample.get("phrases", []) + color_map = build_color_map(phrases) + labels_per_frame = sample.get("labels", []) + bboxes_per_frame = sample.get("bboxes", []) + + reader = imageio.get_reader(clip_path) + fps = sample.get("fps", None) + if fps is None: + try: + fps = reader.get_meta_data().get("fps", 15) + except Exception: + fps = 15 + + writer = imageio.get_writer(output_path, fps=fps, macro_block_size=1) + + try: + max_frames = args.max_frames or len(bboxes_per_frame) + for idx, frame in enumerate(reader): + if idx >= max_frames or idx >= len(bboxes_per_frame): + break + img = Image.fromarray(frame) + labels = labels_per_frame[idx] + bboxes = bboxes_per_frame[idx] + draw_bboxes(img, labels, bboxes, color_map, args.line_width, args.draw_labels) + writer.append_data(np.array(img)) + finally: + reader.close() + writer.close() + + print(f"Saved: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/model_inference/render_iground_bboxes.sh b/examples/wanvideo/model_inference/render_iground_bboxes.sh new file mode 100644 index 0000000000000000000000000000000000000000..5d2b725bbccd24b07c464cbd305d21b0664371b3 --- /dev/null +++ b/examples/wanvideo/model_inference/render_iground_bboxes.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [[ $# -lt 2 ]]; then + echo "Usage: bash render_iground_bboxes.sh [output.mp4]" + exit 1 +fi + +VIDEO_ID="$1" +CLIP_ID="$2" +OUTPUT="${3:-}" + +PROJECT_ROOT="/data/rczhang/PencilFolder/DiffSynth-Studio" +PYTHON_BIN="/home/rczhang/miniconda3/envs/diffsyn/bin/python" + +IGROUND_JSONL="/data/rczhang/PencilFolder/data/iGround/iGround_train_set_processed.jsonl" +CLIPS_DIR="/data/rczhang/PencilFolder/data/iGround/Clips/train" + +ARGS=( + "${PROJECT_ROOT}/examples/wanvideo/model_inference/render_iground_bboxes.py" + --iground_jsonl "${IGROUND_JSONL}" + --clips_dir "${CLIPS_DIR}" + --video_id "${VIDEO_ID}" + --clip_id "${CLIP_ID}" + --line_width 3 +) + +if [[ -n "${OUTPUT}" ]]; then + ARGS+=(--output "${OUTPUT}") +fi + +"${PYTHON_BIN}" "${ARGS[@]}" diff --git a/examples/wanvideo/model_inference/run_instancev_batch_iground.sh b/examples/wanvideo/model_inference/run_instancev_batch_iground.sh new file mode 100644 index 0000000000000000000000000000000000000000..5f4894e21a4e369e11141bf6c429e6ab2bdada84 --- /dev/null +++ b/examples/wanvideo/model_inference/run_instancev_batch_iground.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_ROOT="/data/rczhang/PencilFolder/DiffSynth-Studio" +PYTHON_BIN="/home/rczhang/miniconda3/envs/diffsyn/bin/python" + +export CUDA_VISIBLE_DEVICES=0 +export DIFFSYNTH_SKIP_DOWNLOAD=true + +OUT_DIR="/data/rczhang/PencilFolder/DiffSynth-Studio/outputs/instancev-new" +mkdir -p "${OUT_DIR}" + +META_PATH="/data/rczhang/PencilFolder/data/iGround/instancev_iground_train.jsonl" +NUM_CASES="${NUM_CASES:-20}" +WIDTH="${WIDTH:-832}" +HEIGHT="${HEIGHT:-480}" +NUM_FRAMES="${NUM_FRAMES:-81}" +NUM_INFERENCE_STEPS="${NUM_INFERENCE_STEPS:-20}" +CFG_SCALE="${CFG_SCALE:-5.0}" + +MODEL_DIR="${PROJECT_ROOT}/models" +DIT_PATH="${MODEL_DIR}/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" +TEXT_ENCODER_PATH="${MODEL_DIR}/DiffSynth-Studio/Wan-Series-Converted-Safetensors/models_t5_umt5-xxl-enc-bf16.safetensors" +VAE_PATH="${MODEL_DIR}/DiffSynth-Studio/Wan-Series-Converted-Safetensors/Wan2.1_VAE.safetensors" 
+TOKENIZER_PATH="${MODEL_DIR}/Wan-AI/Wan2.1-T2V-1.3B/google/umt5-xxl" + +TRAIN_DIR="$(ls -1dt "${PROJECT_ROOT}/models/train/instancev_iground_"* | head -n 1)" +CHECKPOINT="$(ls -1t "${TRAIN_DIR}"/step-*.safetensors | head -n 1)" + +export PROJECT_ROOT PYTHON_BIN OUT_DIR META_PATH NUM_CASES WIDTH HEIGHT NUM_FRAMES NUM_INFERENCE_STEPS CFG_SCALE +export DIT_PATH TEXT_ENCODER_PATH VAE_PATH TOKENIZER_PATH CHECKPOINT + +"${PYTHON_BIN}" - <<'PY' +import json +import os +import random +import subprocess +import sys + +project_root = os.environ["PROJECT_ROOT"] +python_bin = os.environ["PYTHON_BIN"] +out_dir = os.environ["OUT_DIR"] +meta_path = os.environ["META_PATH"] +num_cases = int(os.environ["NUM_CASES"]) +width = int(os.environ["WIDTH"]) +height = int(os.environ["HEIGHT"]) +num_frames = int(os.environ["NUM_FRAMES"]) +num_steps = int(os.environ["NUM_INFERENCE_STEPS"]) +cfg_scale = float(os.environ["CFG_SCALE"]) + +dit_path = os.environ["DIT_PATH"] +text_encoder_path = os.environ["TEXT_ENCODER_PATH"] +vae_path = os.environ["VAE_PATH"] +tokenizer_path = os.environ["TOKENIZER_PATH"] +checkpoint = os.environ["CHECKPOINT"] + +random.seed(42) + +samples = [] +with open(meta_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + samples.append(json.loads(line)) +random.shuffle(samples) + +selected = [] +for sample in samples: + if not sample.get("prompt") or not sample.get("instance_prompts"): + continue + selected.append(sample) + if len(selected) >= num_cases: + break + +if len(selected) < num_cases: + print(f"[Warning] Only {len(selected)} samples available.") + +f_lat = (num_frames - 1) // 4 + 1 + +def lerp(a, b, t): + return a + (b - a) * t + +def interp_bbox(start, end, t): + return [ + lerp(start[0], end[0], t), + lerp(start[1], end[1], t), + lerp(start[2], end[2], t), + lerp(start[3], end[3], t), + ] + +starts = [ + [int(0.1 * width), int(0.1 * height), int(0.45 * width), int(0.9 * height)], + [int(0.5 * width), int(0.55 * height), int(0.7 * width), int(0.78 * height)], + [int(0.0 * width), int(0.65 * height), int(1.0 * width), int(0.98 * height)], + [int(0.2 * width), int(0.2 * height), int(0.4 * width), int(0.5 * height)], + [int(0.6 * width), int(0.15 * height), int(0.85 * width), int(0.45 * height)], +] +ends = [ + [int(0.2 * width), int(0.12 * height), int(0.5 * width), int(0.9 * height)], + [int(0.55 * width), int(0.5 * height), int(0.72 * width), int(0.7 * height)], + [int(0.0 * width), int(0.6 * height), int(1.0 * width), int(0.98 * height)], + [int(0.25 * width), int(0.25 * height), int(0.45 * width), int(0.55 * height)], + [int(0.62 * width), int(0.2 * height), int(0.88 * width), int(0.48 * height)], +] + +manifest_path = os.path.join(out_dir, "cases_manifest.jsonl") +manifest = open(manifest_path, "w", encoding="utf-8") + +success = 0 +for idx, sample in enumerate(selected): + prompt = sample["prompt"] + instances = sample["instance_prompts"] + if len(instances) == 0: + continue + + bboxes = [] + for t_idx in range(f_lat): + t = t_idx / max(1, f_lat - 1) + frame = [] + for inst_id in range(len(instances)): + start = starts[inst_id % len(starts)] + end = ends[inst_id % len(ends)] + frame.append(interp_bbox(start, end, t)) + bboxes.append(frame) + + bboxes_path = os.path.join(out_dir, f"case_{idx:02d}_bboxes.json") + with open(bboxes_path, "w", encoding="utf-8") as f: + json.dump(bboxes, f) + + output_path = os.path.join(out_dir, f"case_{idx:02d}.mp4") + + record = { + "case_id": idx, + "prompt": prompt, + "instance_prompts": instances, + "bboxes_json": bboxes_path, + 
"output": output_path, + } + manifest.write(json.dumps(record, ensure_ascii=False) + "\n") + manifest.flush() + + cmd = [ + python_bin, + os.path.join(project_root, "examples/wanvideo/model_inference/instanceV_inference.py"), + "--checkpoint", checkpoint, + "--prompt", prompt, + "--instance_prompts", + *instances, + "--bboxes_json", bboxes_path, + "--height", str(height), + "--width", str(width), + "--num_frames", str(num_frames), + "--num_inference_steps", str(num_steps), + "--cfg_scale", str(cfg_scale), + "--seed", str(1000 + idx), + "--output", output_path, + "--dit_path", dit_path, + "--text_encoder_path", text_encoder_path, + "--vae_path", vae_path, + "--tokenizer_path", tokenizer_path, + ] + + print(f"[Case {idx:02d}] {output_path}") + try: + subprocess.run(cmd, check=True) + success += 1 + except subprocess.CalledProcessError as e: + print(f"[Case {idx:02d}] Failed: {e}", file=sys.stderr) + +manifest.close() +print(f"Done. Success: {success}/{len(selected)}") +print(f"Manifest: {manifest_path}") +PY diff --git a/examples/wanvideo/model_inference/run_instancev_inference.sh b/examples/wanvideo/model_inference/run_instancev_inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..8d2cff46114bf06d1e76ec2af205c6dc208d4861 --- /dev/null +++ b/examples/wanvideo/model_inference/run_instancev_inference.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# InstanceV 推理脚本 +# 用法: bash run_instancev_inference.sh + +cd "$(dirname "$0")/../../.." # 切换到 DiffSynth-Studio 根目录 + +# 配置 +CHECKPOINT="models/train/instancev/step-1500.safetensors" +OUTPUT_DIR="outputs/instancev" +mkdir -p "$OUTPUT_DIR" + +# GPU +export CUDA_VISIBLE_DEVICES=0 + +echo "==============================================" +echo " InstanceV Inference Examples " +echo "==============================================" + +# 示例 1: 多实例(3 个),移动 bbox +echo "" +echo "[Example 1] 多个实例同场景" +python examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "纪实摄影风格画面,绿茵茵的草地,阳光明媚。三只小动物在草地上互动。中景侧面视角。" \ + --instance_prompts \ + "一只棕黄色毛发柔软的小狗,表情欢快" \ + "一只黑白花纹的小猫,神态好奇" \ + "一个红色的飞盘,漂浮在草地上方" \ + --bboxes_json examples/wanvideo/model_inference/example_multi_moving_bbox.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.6 \ + --seed 42 \ + --output "$OUTPUT_DIR/multi_instances_animals.mp4" + +# 示例 2: 多个实例(4 个),移动 bbox +echo "" +echo "[Example 2] 四个人物对话" +python examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "电影画面,室内场景,四个人围成半圈站着交谈,背景是温馨的客厅。中景正面视角。" \ + --instance_prompts \ + "一位穿着蓝色衬衫的年轻男性,短发,表情认真" \ + "一位穿着红色连衣裙的年轻女性,长发,微笑着" \ + "一位穿着灰色毛衣的中年男性,戴眼镜,神情专注" \ + "一位穿着绿色外套的年轻女性,马尾,表情友好" \ + --bboxes_json examples/wanvideo/model_inference/example_four_moving_bbox.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.5 \ + --seed 123 \ + --output "$OUTPUT_DIR/four_people_talking.mp4" + +# 示例 3: 两个实例交错移动(运动幅度略大) +echo "" +echo "[Example 3] 两个运动员交错穿行" +python examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "运动纪录片风格,傍晚的滑板公园,地面有反光,中景跟拍。两位运动员从画面两侧向中间穿行交错。" \ + --instance_prompts \ + "穿黑色连帽衫的滑板男青年,动作敏捷" \ + "穿白色T恤的轮滑女青年,动作轻快" \ + --bboxes_json examples/wanvideo/model_inference/example_two_crossing_big_motion.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.6 \ + --seed 7 \ + --output "$OUTPUT_DIR/two_crossing_athletes.mp4" + +# 示例 4: 三个实例对角线移动(运动幅度略大) +echo "" +echo "[Example 4] 三个主体对角线移动" +python examples/wanvideo/model_inference/instanceV_inference.py \ + 
--checkpoint "$CHECKPOINT" \ + --prompt "清晨公园航拍感画面,草地和步道清晰可见,中景侧俯视。三个主体朝不同方向移动,画面动感明显。" \ + --instance_prompts \ + "一架银灰色小型无人机,机身反光" \ + "一位穿橙色运动服的慢跑者,动作连贯" \ + "一只棕色小狗,四肢有力,奔跑姿态" \ + --bboxes_json examples/wanvideo/model_inference/example_three_diagonal_big_motion.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.6 \ + --seed 21 \ + --output "$OUTPUT_DIR/three_diagonal_motion.mp4" + +# 示例 5: 单实例大幅度横向扫过(运动幅度略大) +echo "" +echo "[Example 5] 单个主体大幅度扫过" +python examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "电影质感街景,白天,路面略有水渍反光。一个红色跑车从画面左侧大幅度驶向右侧,镜头稳定。" \ + --instance_prompts \ + "一辆红色跑车,车身线条流畅" \ + --bboxes_json examples/wanvideo/model_inference/example_single_sweep_big_motion.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.55 \ + --seed 88 \ + --output "$OUTPUT_DIR/single_car_sweep.mp4" + +echo "" +echo "==============================================" +echo "完成!视频保存在: $OUTPUT_DIR/" +echo "==============================================" diff --git a/examples/wanvideo/model_inference/run_instancev_inference_extra.sh b/examples/wanvideo/model_inference/run_instancev_inference_extra.sh new file mode 100644 index 0000000000000000000000000000000000000000..086bf753ee5dc3de146eb82616a01733bf6e366f --- /dev/null +++ b/examples/wanvideo/model_inference/run_instancev_inference_extra.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run from repo root (DiffSynth-Studio/) +cd "$(dirname "$0")/../../.." + +export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}" + +if [[ "${CONDA_DEFAULT_ENV:-}" == "diffsyn" ]]; then + PYTHON_BIN="python" +else + PYTHON_BIN="conda run -n diffsyn python" +fi + +CHECKPOINT="${CHECKPOINT:-models/train/instancev/step-1500.safetensors}" +OUTPUT_DIR="${OUTPUT_DIR:-outputs/instancev}" +mkdir -p "$OUTPUT_DIR" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" +TIMESTAMP="$(date +"%Y%m%d_%H%M%S")" + +echo "==============================================" +echo " InstanceV Extra Inference Examples (5) " +echo "==============================================" + +# Example 6: two scooters crossing +$PYTHON_BIN examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "夜晚城市街道,霓虹灯映在湿润路面,中景稳定跟拍。两辆小摩托从画面两侧交错穿行。" \ + --instance_prompts \ + "蓝色小摩托,骑手戴黑色头盔,动作迅捷" \ + "黄色小摩托,骑手戴白色头盔,动作稳健" \ + --bboxes_json examples/wanvideo/model_inference/example_two_scooters_crossing.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.6 \ + --seed 101 \ + --output "$OUTPUT_DIR/two_scooters_crossing_${TIMESTAMP}.mp4" + +# Example 7: two students + drone +$PYTHON_BIN examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "安静图书馆大厅,暖色灯光,镜头固定。两位学生在书架前相向走近,上方有一架小型无人机缓慢漂浮。" \ + --instance_prompts \ + "穿浅灰毛衣的学生,背着书包,步伐轻缓" \ + "穿深蓝外套的学生,手拿书本,表情专注" \ + "小型白色无人机,机身轻轻晃动" \ + --bboxes_json examples/wanvideo/model_inference/example_two_students_drone.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.55 \ + --seed 202 \ + --output "$OUTPUT_DIR/two_students_drone_${TIMESTAMP}.mp4" + +# Example 8: four pigeons orbiting +$PYTHON_BIN examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "晴朗城市广场,中央喷泉,浅景深,中景平视。四只鸽子绕着喷泉做环形飞行。" \ + --instance_prompts \ + "一只灰白鸽子,翅膀快速扇动" \ + "一只深灰鸽子,飞行更平稳" \ + "一只带斑点的鸽子,飞行略有起伏" \ + "一只浅灰鸽子,转弯灵活" \ + --bboxes_json examples/wanvideo/model_inference/example_four_pigeons_orbit.json \ + --height 480 --width 832 --num_frames 
81 \ + --saug_scale 0.5 \ + --seed 303 \ + --output "$OUTPUT_DIR/four_pigeons_orbit_${TIMESTAMP}.mp4" + +# Example 9: boat + seagull +$PYTHON_BIN examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "海港清晨,薄雾,光线柔和,中景固定镜头。一艘渔船缓慢驶过,空中有海鸥掠过。" \ + --instance_prompts \ + "木质渔船,船身略有反光,航行平稳" \ + "白色海鸥,翅膀展开,掠过天空" \ + --bboxes_json examples/wanvideo/model_inference/example_boat_seagull.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.55 \ + --seed 404 \ + --output "$OUTPUT_DIR/boat_seagull_${TIMESTAMP}.mp4" + +# Example 10: deer approaching +$PYTHON_BIN examples/wanvideo/model_inference/instanceV_inference.py \ + --checkpoint "$CHECKPOINT" \ + --prompt "森林小径,清晨逆光,镜头稳定。一只鹿从远处缓缓走近镜头,树影轻晃。" \ + --instance_prompts \ + "一只棕色鹿,步态缓慢,耳朵轻动" \ + --bboxes_json examples/wanvideo/model_inference/example_deer_approach.json \ + --height 480 --width 832 --num_frames 81 \ + --saug_scale 0.6 \ + --seed 505 \ + --output "$OUTPUT_DIR/deer_approach_${TIMESTAMP}.mp4" + +echo "==============================================" +echo "完成!视频保存在: $OUTPUT_DIR/" +echo "==============================================" diff --git a/examples/wanvideo/model_inference/run_instancev_inference_local.sh b/examples/wanvideo/model_inference/run_instancev_inference_local.sh new file mode 100644 index 0000000000000000000000000000000000000000..2ec508b441843735d65835bc7d923792787479e5 --- /dev/null +++ b/examples/wanvideo/model_inference/run_instancev_inference_local.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_ROOT="/data/rczhang/PencilFolder/DiffSynth-Studio" +PYTHON_BIN="/home/rczhang/miniconda3/envs/diffsyn/bin/python" + +export CUDA_VISIBLE_DEVICES=0 +export DIFFSYNTH_SKIP_DOWNLOAD=true + +MODEL_DIR="${PROJECT_ROOT}/models" +DIT_PATH="${MODEL_DIR}/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" +TEXT_ENCODER_PATH="${MODEL_DIR}/DiffSynth-Studio/Wan-Series-Converted-Safetensors/models_t5_umt5-xxl-enc-bf16.safetensors" +VAE_PATH="${MODEL_DIR}/DiffSynth-Studio/Wan-Series-Converted-Safetensors/Wan2.1_VAE.safetensors" +TOKENIZER_PATH="${MODEL_DIR}/Wan-AI/Wan2.1-T2V-1.3B/google/umt5-xxl" + +TRAIN_DIR="$(ls -1dt "${PROJECT_ROOT}/models/train/instancev_iground_"* | head -n 1)" +CHECKPOINT="$(ls -1t "${TRAIN_DIR}"/step-*.safetensors | head -n 1)" + +OUT_DIR="${PROJECT_ROOT}/outputs" +mkdir -p "${OUT_DIR}" +OUTPUT="${OUT_DIR}/instancev_iground_infer_$(date +%Y%m%d_%H%M%S).mp4" + +PROMPT=${PROMPT:-"a woman pours coffee into a cup on a table"} +INSTANCE_PROMPTS=${INSTANCE_PROMPTS:-"woman;coffee cup;table"} +WIDTH=${WIDTH:-832} +HEIGHT=${HEIGHT:-480} +NUM_FRAMES=${NUM_FRAMES:-81} +NUM_INFERENCE_STEPS=${NUM_INFERENCE_STEPS:-25} +CFG_SCALE=${CFG_SCALE:-5.0} + +BBOX_JSON="${OUT_DIR}/instancev_bboxes_$(date +%Y%m%d_%H%M%S).json" + +export INSTANCE_PROMPTS WIDTH HEIGHT NUM_FRAMES BBOX_JSON + +"${PYTHON_BIN}" - <<'PY' +import json +import os + +prompt_instances = os.environ["INSTANCE_PROMPTS"].split(";") +width = int(os.environ["WIDTH"]) +height = int(os.environ["HEIGHT"]) +num_frames = int(os.environ["NUM_FRAMES"]) +f_lat = (num_frames - 1) // 4 + 1 + +def lerp(a, b, t): + return a + (b - a) * t + +def interp_bbox(start, end, t): + return [ + lerp(start[0], end[0], t), + lerp(start[1], end[1], t), + lerp(start[2], end[2], t), + lerp(start[3], end[3], t), + ] + +starts = [ + [int(0.1 * width), int(0.1 * height), int(0.45 * width), int(0.9 * height)], + [int(0.5 * width), int(0.55 * height), int(0.65 * width), int(0.75 * height)], + [int(0.0 * width), 
int(0.65 * height), int(1.0 * width), int(0.98 * height)], +] +ends = [ + [int(0.2 * width), int(0.12 * height), int(0.5 * width), int(0.9 * height)], + [int(0.55 * width), int(0.5 * height), int(0.7 * width), int(0.7 * height)], + [int(0.0 * width), int(0.6 * height), int(1.0 * width), int(0.98 * height)], +] + +bboxes = [] +for i in range(f_lat): + t = i / max(1, (f_lat - 1)) + frame = [] + for idx in range(len(prompt_instances)): + start = starts[idx % len(starts)] + end = ends[idx % len(ends)] + frame.append(interp_bbox(start, end, t)) + bboxes.append(frame) + +with open(os.environ["BBOX_JSON"], "w") as f: + json.dump(bboxes, f) +print(f"Wrote bboxes: {os.environ['BBOX_JSON']}") +PY + +IFS=';' read -r -a INSTANCE_ARR <<< "${INSTANCE_PROMPTS}" +INSTANCE_ARGS=() +for p in "${INSTANCE_ARR[@]}"; do + INSTANCE_ARGS+=("$p") +done + +"${PYTHON_BIN}" "${PROJECT_ROOT}/examples/wanvideo/model_inference/instanceV_inference.py" \ + --checkpoint "${CHECKPOINT}" \ + --prompt "${PROMPT}" \ + --instance_prompts "${INSTANCE_ARGS[@]}" \ + --bboxes_json "${BBOX_JSON}" \ + --height "${HEIGHT}" \ + --width "${WIDTH}" \ + --num_frames "${NUM_FRAMES}" \ + --num_inference_steps "${NUM_INFERENCE_STEPS}" \ + --cfg_scale "${CFG_SCALE}" \ + --output "${OUTPUT}" \ + --dit_path "${DIT_PATH}" \ + --text_encoder_path "${TEXT_ENCODER_PATH}" \ + --vae_path "${VAE_PATH}" \ + --tokenizer_path "${TOKENIZER_PATH}" + +echo "Output saved to: ${OUTPUT}" diff --git a/examples/wanvideo/model_inference/test_text_sim.py b/examples/wanvideo/model_inference/test_text_sim.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a6e711babbb39245a904027da6290851f3bda2 --- /dev/null +++ b/examples/wanvideo/model_inference/test_text_sim.py @@ -0,0 +1,157 @@ +import argparse +from pathlib import Path +import torch +import torch.nn.functional as F +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[3] + +# ========================================== +# 辅助函数 +# ========================================== +def get_sentence_embedding(pipe, text, device): + tokenized = pipe.tokenizer( + text, max_length=512, padding="max_length", truncation=True, return_tensors="pt" + ) + if isinstance(tokenized, torch.Tensor): + input_ids = tokenized.to(device) + attention_mask = (input_ids != 0).long().to(device) + else: + input_ids = tokenized.input_ids.to(device) + attention_mask = tokenized.attention_mask.to(device) + + try: + encoder_outputs = pipe.text_encoder(input_ids, attention_mask) + except: + encoder_outputs = pipe.text_encoder([text], device=device) + + if hasattr(encoder_outputs, "last_hidden_state"): + embeddings = encoder_outputs.last_hidden_state + elif isinstance(encoder_outputs, (tuple, list)): + embeddings = encoder_outputs[0] + else: + embeddings = encoder_outputs + + mask_expanded = attention_mask.unsqueeze(-1).float() + masked_embeddings = embeddings * mask_expanded + sum_embeddings = torch.sum(masked_embeddings, dim=1) + sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9) + return sum_embeddings / sum_mask + +def slerp(v0, v1, t): + v0 = v0.float() + v1 = v1.float() + v0_norm = v0 / torch.norm(v0, dim=-1, keepdim=True) + v1_norm = v1 / torch.norm(v1, dim=-1, keepdim=True) + dot = torch.sum(v0_norm * v1_norm, dim=-1) + dot = torch.clamp(dot, -1.0, 1.0) + theta = torch.acos(dot) + sin_theta = torch.sin(theta) + if sin_theta < 1e-6: return (1-t)*v0 + t*v1 + w0 = torch.sin((1.0 - t) * theta) / sin_theta + w1 = torch.sin(t * 
theta) / sin_theta + return w0 * v0 + w1 * v1 + +# ========================================== +# 主程序 +# ========================================== +def main(): + root_dir = _repo_root() + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--model_dir", type=Path, default=(root_dir / "models" / "Wan-AI" / "Wan2.1-T2V-1.3B")) + args = parser.parse_args() + + print(f"Loading model from {args.model_dir}...") + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, device=args.device, + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + ) + + # 目标:半蹲 + target_text = "人 在 半蹲" + target_vec = get_sentence_embedding(pipe, target_text, args.device) + + print(f"\n🎯 目标概念: '{target_text}'\n" + "="*50) + + # ------------------------------------------------- + # 方法 1: SLERP (基准线) + # ------------------------------------------------- + t1 = "人 在 蹲" + t2 = "人 在 站" + v1 = get_sentence_embedding(pipe, t1, args.device) + v2 = get_sentence_embedding(pipe, t2, args.device) + # 取最佳比例 0.4 + vec_slerp = slerp(v1, v2, 0.4) + sim_slerp = F.cosine_similarity(vec_slerp, target_vec).item() + + # ------------------------------------------------- + # 方法 2: 语言融合 (Prompt Concatenation) + # ------------------------------------------------- + # 直接告诉模型两个状态都有 + text_concat = "人 保持 在 站 和 蹲 之间 的 姿势" + vec_concat = get_sentence_embedding(pipe, text_concat, args.device) + sim_concat = F.cosine_similarity(vec_concat, target_vec).item() + + # ------------------------------------------------- + # 方法 3: 语义加法 (Vector Arithmetic) + # ------------------------------------------------- + # 站立 + 弯曲膝盖 = 半蹲? + t_base = "人 在 站" + t_modifier = "弯曲 膝盖" + v_base = get_sentence_embedding(pipe, t_base, args.device) + v_mod = get_sentence_embedding(pipe, t_modifier, args.device) + + # 将“弯曲膝盖”加到“站”上 (权重需要调整,假设 0.6) + vec_arithmetic = v_base + v_mod * 0.6 + # 归一化一下,防止模长爆炸 + vec_arithmetic = vec_arithmetic / vec_arithmetic.norm() * target_vec.norm() + sim_arithmetic = F.cosine_similarity(vec_arithmetic, target_vec).item() + + # ------------------------------------------------- + # 方法 4: 物理描述 (LLM Style) + # ------------------------------------------------- + # 描述动作细节,而不是动作名称 + text_phys = "人 膝盖 弯曲 90度 保持 平衡" + vec_phys = get_sentence_embedding(pipe, text_phys, args.device) + sim_phys = F.cosine_similarity(vec_phys, target_vec).item() + + # ------------------------------------------------- + # 结果展示 + # ------------------------------------------------- + print(f"{'Method':<25} | {'Text / Formula':<40} | {'Similarity':<10}") + print("-" * 80) + print(f"{'1. SLERP (Math)':<25} | {'(蹲 * 0.6 + 站 * 0.4)':<40} | {sim_slerp:.4f}") + print(f"{'2. Language Mix':<25} | {'{text_concat}':<40} | {sim_concat:.4f}") + print(f"{'3. Vector Add':<25} | {'站 + (弯曲膝盖 * 0.6)':<40} | {sim_arithmetic:.4f}") + print(f"{'4. 
Physical Desc':<25} | {text_phys:<40} | {sim_phys:.4f}") + print("=" * 80) + + # Pick the winner + scores = { + "SLERP": sim_slerp, + "Lang Mix": sim_concat, + "Vector Add": sim_arithmetic, + "Physical": sim_phys + } + winner = max(scores, key=scores.get) + print(f"🏆 Winning strategy: {winner} (Sim: {scores[winner]:.4f})") + + print("\n💡 Analysis:") + if winner == "SLERP": + print("    Mathematical interpolation actually won? That suggests T5's embedding space is very linear for these words.") + elif winner == "Physical": + print("    The descriptive text won. T5 apparently understands 'descriptions of visual states' better than abstract action verbs.") + elif winner == "Vector Add": + print("    Vector addition won. The difference between 'standing' and 'half-squatting' mostly comes down to the 'bent knees' feature.") + elif winner == "Lang Mix": + print("    Natural-language mixing won. Letting the model interpret the word 'between' is more accurate than the math.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/wanvideo/model_inference_low_vram/LongCat-Video.py b/examples/wanvideo/model_inference_low_vram/LongCat-Video.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1c4faf3864b605329f7c43331842ed99133f86 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/LongCat-Video.py @@ -0,0 +1,46 @@ +import torch +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=0, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, +) +save_video(video, "video_1_LongCat-Video.mp4", fps=15, quality=5) + +# Video-continuation (The number of frames in `longcat_video` should be 4n+1.) +longcat_video = video[-17:] +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. 
He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=1, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, + longcat_video=longcat_video, +) +save_video(video, "video_2_LongCat-Video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py b/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..b1632b4797adcdf79fa935b485ea49cf3e4b24d5 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Video-As-Prompt-Wan2.1-14B.py @@ -0,0 +1,62 @@ +import torch +import PIL +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download +from typing import List + + +# This model doesn't support fine-grained VRAM Management due to its special architecture. +# Only CPU Offload is supported. +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="wanvap/*", local_dir="data/example_video_dataset") +ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4' +target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg' + +def select_frames(video_frames, num): + idx = torch.linspace(0, len(video_frames) - 1, num).long().tolist() + return [video_frames[i] for i in idx] + +image = Image.open(target_image_path).convert("RGB") +ref_video = VideoData(ref_video_path, height=480, width=832) +ref_frames = select_frames(ref_video, num=49) + +vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. 
He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery." +prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent." +negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + input_image=image, + seed=42, tiled=True, + height=480, width=832, + num_frames=49, + vap_video=ref_frames, + vap_prompt=vap_prompt, + negative_vap_prompt=negative_prompt, +) +save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8f5d18f7911cb817c7a4901f35d47851a9432e --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-1.3b-speedcontrol-v1.py @@ -0,0 +1,45 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + 
ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=0 +) +save_video(video, "video_slow_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) + +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=100 +) +save_video(video, "video_fast_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000000000000000000000000000000000000..5af7fb7ca1b4afb20d67963fea5dbedb0934c9f0 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + seed=0, 
tiled=True, + height=960, width=960, num_frames=33, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..500db9f81a75f19bf40eec6f8b4f5b3bb0a90b59 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,45 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..d3533ba60156d0722ba469e88b4127c6f7ea7a06 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", 
origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa26f020ee26109fd47ec4de4c814407f3152b7 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,45 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/control_video.mp4" +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..c56dbffbeae7ebd5d1a683cd46ca7d7b4266f47a --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
+) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..658c0ebc58b5d06bd59b2e8e753f067f6a124c99 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py @@ -0,0 +1,55 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..20eb6871236241fada69697a61d2f1e1325c31b3 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,47 @@ +import torch +from PIL import 
Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..fc95d4c5bf6b46bf4c52c7ed5f17a63a6c88819f --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-1.3B-InP.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", 
origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..37434ba16223612f9eb961528d990b023e52be2f --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -0,0 +1,55 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", 
camera_control_speed=0.01, +) +save_video(video, "video_left_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb66f3e34e9a2ae59a3b4ed662fccb79a9348ff --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ae5880fed7d7e065affa6f790b2e9baf9f4e73 --- /dev/null +++ 
b/examples/wanvideo/model_inference_low_vram/Wan2.1-Fun-V1.1-14B-InP.py @@ -0,0 +1,47 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
+) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000000000000000000000000000000000000..b14412c8e4839092ed81a868ba89a85a0a756ff8 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,45 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec01b4e3b752dfdca1f7288e788e53cdd3b1d45 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,46 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + 
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + height=720, width=1280, +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000000000000000000000000000000000000..6f90d68aec798fd7f0d45ea06ba618e77f1cf722 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-1.3B.py @@ -0,0 +1,45 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_1_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) + +# Video-to-video +video = VideoData("video_1_Wan2.1-T2V-1.3B.mp4", height=480, width=832) +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_video=video, denoising_strength=0.7, + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..68555b56d051079ed1737b6cbb6236ed1f037533 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-T2V-14B.py @@ -0,0 +1,35 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py new file mode 100644 index 0000000000000000000000000000000000000000..5f118198cfbfa48d1637bf885052286e814a2426 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B-Preview.py @@ -0,0 +1,63 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + 
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py new file mode 100644 index 0000000000000000000000000000000000000000..60bd8c0e551db14e2e6d08ab80bf79d03315e5bd --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-1.3B.py @@ -0,0 +1,64 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) 
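+# Descriptive note on the three calls below: `vace_video` supplies the structural control signal (here the depth video), +# `vace_reference_image` in the second and third calls supplies an appearance reference, and the final call combines both.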
+video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa795e36378273c1fc56ecc7a996cfc1923ff26 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.1-VACE-14B.py @@ -0,0 +1,65 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.1-VACE-14B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.1-VACE-14B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.1-VACE-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..180482c14e489a87324a1337adcb3bd7510053c7 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py @@ -0,0 +1,74 @@ +import torch +from PIL import Image +from diffsynth.core import load_state_dict +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download, snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/wan/animate/*", +) + +# Animate +input_image = Image.open("data/examples/wan/animate/animate_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5) + +# Replace +snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B") +lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", 
torch_dtype=torch.bfloat16, device="cuda")["state_dict"] +lora_state_dict = {i: lora_state_dict[i].to(torch.bfloat16) for i in lora_state_dict} +pipe.load_lora(pipe.dit, state_dict=lora_state_dict) +input_image = Image.open("data/examples/wan/animate/replace_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4] +animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4] +animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + animate_inpaint_video=animate_inpaint_video, + animate_mask_video=animate_mask_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_2_Wan2.2-Animate-14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py new file mode 100644 index 0000000000000000000000000000000000000000..760bd7e271739af44127e582cc862f2c39902f2c --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control-Camera.py @@ -0,0 +1,55 @@ +import torch +from diffsynth.utils.data import save_video,VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, 
"video_left_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py new file mode 100644 index 0000000000000000000000000000000000000000..df92f7d2306f8f0ce2c26453e8c64f9f73581999 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-Control.py @@ -0,0 +1,46 @@ +import torch +from diffsynth.utils.data import save_video,VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-Fun-A14B-Control.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py new file mode 100644 index 0000000000000000000000000000000000000000..34f8cfaa99e76e3c784c8f14c4d9dec65d0a8355 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Fun-A14B-InP.py @@ -0,0 
+1,46 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. 
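+    # For illustration only (hypothetical path), an end-frame argument would look like:
+    # end_image=Image.open("path/to/end_frame.jpg"),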
+) +save_video(video, "video_Wan2.2-Fun-A14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc8a6715e5cf03840f2b83feb59b708eba2c5b9 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-I2V-A14B.py @@ -0,0 +1,44 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) + +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + switch_DiT_boundary=0.9, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B.py new file mode 100644 index 0000000000000000000000000000000000000000..14be8c919568978b0d3303099f33a2c6c627c19b --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B.py @@ -0,0 +1,84 @@ +# This script can generate a single video clip. +# If you need generate long videos, please refer to `Wan2.2-S2V-14B_multi_clips.py`. 
+import torch +from PIL import Image +import librosa +from diffsynth.utils.data import VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*" +) + +num_frames = 81 # 4n+1 +height = 448 +width = 832 + +prompt = "a person is singing" +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) +# s2v audio input, recommend 16kHz sampling rate +audio_path = 'data/example_video_dataset/wans2v/sing.MP3' +input_audio, sample_rate = librosa.load(audio_path, sr=16000) + +# Speech-to-video +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + num_inference_steps=40, +) +save_video_with_audio(video[1:], "video_1_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5) + +# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. 
+pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4' +pose_video = VideoData(pose_video_path, height=height, width=width) + +# Speech-to-video with pose +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + s2v_pose_video=pose_video, + num_inference_steps=40, +) +save_video_with_audio(video[1:], "video_2_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py new file mode 100644 index 0000000000000000000000000000000000000000..c1995cf644955889c0d7a7330b7e234d25d60b67 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-S2V-14B_multi_clips.py @@ -0,0 +1,128 @@ +import torch +from PIL import Image +import librosa +from diffsynth.utils.data import VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V +from modelscope import dataset_snapshot_download + + +def speech_to_video( + prompt, + input_image, + audio_path, + negative_prompt="", + num_clip=None, + audio_sample_rate=16000, + pose_video_path=None, + infer_frames=80, + height=448, + width=832, + num_inference_steps=40, + fps=16, # recommend fixing fps as 16 for s2v + motion_frames=73, # hyperparameter of wan2.2-s2v + save_path=None, +): + # s2v audio input, recommend 16kHz sampling rate + input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate) + # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. 
+ pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None + + audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose( + pipe=pipe, + input_audio=input_audio, + audio_sample_rate=sample_rate, + s2v_pose_video=pose_video, + num_frames=infer_frames + 1, + height=height, + width=width, + fps=fps, + ) + num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat + print(f"Generating {num_repeat} video clips...") + motion_videos = [] + video = [] + for r in range(num_repeat): + s2v_pose_latents = pose_latents[r] if pose_latents is not None else None + current_clip = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt=negative_prompt, + seed=0, + num_frames=infer_frames + 1, + height=height, + width=width, + audio_embeds=audio_embeds[r], + s2v_pose_latents=s2v_pose_latents, + motion_video=motion_videos, + num_inference_steps=num_inference_steps, + ) + current_clip = current_clip[-infer_frames:] + if r == 0: + current_clip = current_clip[3:] + overlap_frames_num = min(motion_frames, len(current_clip)) + motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:] + video.extend(current_clip) + save_video_with_audio(video, save_path, audio_path, fps=16, quality=5) + print(f"processed the {r+1}th clip of total {num_repeat} clips.") + return video + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*", +) + +infer_frames = 80 # 4n +height = 448 +width = 832 + +prompt = "a person is singing" +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) + +video_with_audio = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_full_Wan2.2-S2V-14B.mp4", + infer_frames=infer_frames, + height=height, + width=width, +) +# num_clip means generating only the first n clips with n * infer_frames frames. 
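+# Below, num_clip=2 restricts generation to the first two clips of the audio track.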
+video_with_audio_pose = speech_to_video( + prompt=prompt, + input_image=input_image, + audio_path='data/example_video_dataset/wans2v/sing.MP3', + negative_prompt=negative_prompt, + pose_video_path='data/example_video_dataset/wans2v/pose.mp4', + save_path="video_clip_2_Wan2.2-S2V-14B.mp4", + num_clip=2 +) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5155c71b160a7457965a53a5aee69e6bb6075c --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-T2V-A14B.py @@ -0,0 +1,35 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000000000000000000000000000000000000..34cbfae93d6980d6f162b4d6ec93021c4e277120 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-TI2V-5B.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", **vram_config), + ], + 
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + num_frames=121, +) +save_video(video, "video_1_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video_2_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py new file mode 100644 index 0000000000000000000000000000000000000000..3474b01d9236b9605acc8c88f822ac71e91cb7e8 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py @@ -0,0 +1,65 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video_1_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + 
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_2_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video_3_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py b/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc7a3f1e667a27fcf902cdbed942922dfc54281 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/krea-realtime-video.py @@ -0,0 +1,36 @@ +import torch +from diffsynth.utils.data import save_video +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +# Text-to-video +video = pipe( + prompt="a cat sitting on a boat", + num_inference_steps=6, num_frames=81, + seed=0, tiled=True, + cfg_scale=1, + sigma_shift=20, +) +save_video(video, "video_krea-realtime-video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/INSTANCEV_IGROUND_PIPELINE.md b/examples/wanvideo/model_training/INSTANCEV_IGROUND_PIPELINE.md new file mode 100644 index 0000000000000000000000000000000000000000..aa427743958e9fc0aeea7e69b65a699b812c185c --- /dev/null +++ b/examples/wanvideo/model_training/INSTANCEV_IGROUND_PIPELINE.md @@ -0,0 +1,100 @@ +# InstanceV iGround Training Pipeline (Summary) + +This document describes how iGround processed JSONL is aligned to the current InstanceV training code. + +## Goal + +- Convert iGround processed clips into InstanceV metadata format. +- Generate per-instance masks from bboxes (per frame). +- Train at instance level with the existing InstanceV trainer. +- Output checkpoints to a new, timestamped directory. 
+ +## Data Inputs + +- iGround processed JSONL: + - `/data/rczhang/PencilFolder/data/iGround/iGround_train_set_processed.jsonl` +- iGround clip videos: + - `/data/rczhang/PencilFolder/data/iGround/Clips/train/*.mp4` + +## Metadata Output + +Generated JSONL: + +- `/data/rczhang/PencilFolder/data/iGround/instancev_iground_train.jsonl` + +Per-line schema: + +```json +{ + "video": "iGround/Clips/train/_.mp4", + "prompt": "", + "instance_prompts": ["phrase1", "phrase2", ...], + "instance_mask_dirs": [ + {"mask_dir": "/abs/path/to/masks", "instance_id": 0, "num_frames": 49}, + ... + ] +} +``` + +Notes: +- `video` is stored as a path relative to dataset base path (`/data/rczhang/PencilFolder/data`). +- `instance_prompts` comes from iGround `phrases` filtered by visibility in labels. +- `instance_mask_dirs` is per instance and includes original frame count. + +## Mask Generation + +Script: +- `DiffSynth-Studio/examples/wanvideo/model_training/prepare_instancev_iground.py` + +Process: +- For each clip, parse per-frame `labels` and `bboxes`. +- For each instance phrase, create a binary mask per frame. +- Save masks as PNGs: + - `/_No..png` + +Mask root: +- `/data/rczhang/PencilFolder/data/iGround/InstanceMasks/train` + +Failure handling: +- Skip missing video files. +- Skip unreadable videos (ffmpeg metadata error). +- Skip samples with no visible instances (if any). + +## Trainer Alignment + +Loader changes (InstanceV): +- `LoadInstanceMasks` in `train_instancev.py` now: + - Uses per-sample `num_frames` from metadata. + - Adjusts frame count to match model time stride. + - Center-crops/resizes masks to match video preprocessing size. + +Training launcher: +- `run_instancev_training.sh` uses: + - iGround metadata output + - iGround mask root + - Single GPU (`CUDA_VISIBLE_DEVICES=0`) and `--num_processes 1` + - Output directory with timestamp + +## Current Training Command + +Launcher: +- `DiffSynth-Studio/examples/wanvideo/model_training/run_instancev_training.sh` + +It runs: +- `prepare_instancev_iground.py` (metadata + masks) +- `train_instancev.py` (InstanceV training) + +## Output + +Checkpoints and logs: +- `DiffSynth-Studio/models/train/instancev_iground_/` + +Wandb: +- Project: `instancev-training` +- Run name: `instancev_` + +## Known Limitations + +- If a clip is corrupt or missing `moov` atom, it is skipped during metadata build. +- Single GPU training may still need memory tuning (if OOM). 
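+
+## Mask Generation Sketch
+
+As a reference for the mask-generation step described above, below is a minimal, hedged sketch of rasterizing per-frame bboxes into binary PNG masks for a single instance. The function name, the box format `[x1, y1, x2, y2]`, and the output file naming are illustrative assumptions; the actual logic lives in `prepare_instancev_iground.py`.
+
+```python
+import os
+import numpy as np
+from PIL import Image
+
+def bboxes_to_mask_pngs(bboxes, height, width, out_dir):
+    # bboxes: one [x1, y1, x2, y2] box per frame, or None when the instance
+    # is not visible in that frame. Writes one binary PNG mask per frame.
+    os.makedirs(out_dir, exist_ok=True)
+    for frame_idx, box in enumerate(bboxes):
+        mask = np.zeros((height, width), dtype=np.uint8)
+        if box is not None:
+            x1, y1, x2, y2 = [int(round(v)) for v in box]
+            x1, x2 = max(x1, 0), min(x2, width)
+            y1, y2 = max(y1, 0), min(y2, height)
+            if x2 > x1 and y2 > y1:
+                mask[y1:y2, x1:x2] = 255  # filled box marks the visible instance region
+        Image.fromarray(mask).save(os.path.join(out_dir, f"{frame_idx:05d}.png"))
+
+# Example: an instance visible in frames 0 and 2 of a 256x448 clip.
+bboxes_to_mask_pngs([[10, 20, 100, 120], None, [12, 22, 102, 122]], 256, 448, "masks/instance_000")
+```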
+ diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a6706af006883b3718175b121c304603db6385b8 --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:443bfe7058a01aebc0c141593c5153a729c80636e78eafd57dfafb1a047cbcc8 +size 324850 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bd61ebd8fbc6094c0fd09d0b69854c68153c8f0f --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d867d85908f3dc434eced4029df560d7c6335b52fd18588ed42230e7bcce20c0 +size 1228430 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ff839184655533c9355c78a7954a3cddaff52d8a --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c883827ae54436fd265646807edc9706b9c377a6c3409f72fcd7ee8cbbdcab92 +size 350340 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_64x64.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_64x64.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..91bfbec31c8d5edaecc7937fa1cf13d8600bfc83 Binary files /dev/null and b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_64x64.mp4 differ diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bd61ebd8fbc6094c0fd09d0b69854c68153c8f0f --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d867d85908f3dc434eced4029df560d7c6335b52fd18588ed42230e7bcce20c0 +size 1228430 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e6c55d3db60f648415dcf765b7d31f2cef07c141 --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4364edceb44835e48f616316c0f55f80e935b4a51af60d167280ff7e08025c3f +size 229433 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_class_ids.json b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_class_ids.json new file mode 100644 index 0000000000000000000000000000000000000000..e73306989a3fc47b526975fa4afb48525bfb3403 --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_class_ids.json @@ -0,0 +1 @@ +[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] \ No newline at end of file diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_ids.json b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_ids.json new file mode 100644 index 0000000000000000000000000000000000000000..e73306989a3fc47b526975fa4afb48525bfb3403 --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_ids.json @@ -0,0 +1 @@ +[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] \ No newline at end of file diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_masks.npy b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_masks.npy new file mode 100644 index 0000000000000000000000000000000000000000..61261641b584cc1ac08c9bd63cfba37b3ee65f3c --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_masks.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80ad29bbc594c039bac4497290875d5e6b741f722c5d7c6c6998b0325532f911 +size 3314483328 diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_state_ids.json b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_state_ids.json new file mode 100644 index 0000000000000000000000000000000000000000..b7a381a6fde955f134e97c880212f12476e2147a --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/egg_instance_state_ids.json @@ -0,0 +1 @@ +[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]] \ No newline at end of file diff --git a/examples/wanvideo/model_training/egg_statemachine_dataset/metadata.json b/examples/wanvideo/model_training/egg_statemachine_dataset/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..32f89f85d2c9200ae01b58a8dab13c5d6595785e --- /dev/null +++ b/examples/wanvideo/model_training/egg_statemachine_dataset/metadata.json @@ -0,0 +1,10 @@ +[ + { + "video": "egg_8fps_448x256.mp4", + "prompt": "an egg", + "instance_class_ids": "egg_instance_class_ids.json", + "instance_state_ids": "egg_instance_state_ids.json", + "instance_ids": "egg_instance_ids.json", + "instance_masks": "egg_instance_masks.npy" + } +] \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/LongCat-Video.sh b/examples/wanvideo/model_training/full/LongCat-Video.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d8902e6712cb9bde92bad188dd4336591e36ca5 --- /dev/null +++ b/examples/wanvideo/model_training/full/LongCat-Video.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths 
"meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LongCat-Video_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh b/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..6be4d1239c183d77cf3312c3bfff1918013371bb --- /dev/null +++ b/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vap.csv \ + --data_file_keys "video,vap_video" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vap." \ + --output_path "./models/train/Video-As-Prompt-Wan2.1-14B_full" \ + --trainable_models "vap" \ + --extra_inputs "vap_video,input_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000000000000000000000000000000000000..3d580ab95b96c0b727dc2874d0178110a8eec85c --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.motion_controller." 
\ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_full" \ + --trainable_models "motion_controller" \ + --extra_inputs "motion_bucket_id" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000000000000000000000000000000000000..baf98a9514b69d2c875cfeca8498b8c095d4fb96 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ + --initialize_model_on_cpu \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..45a99ded5ad105c775b4ee641164fb21e2ba780f --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..a202bf9890797c719ec2e67dda93ea35b0273394 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a17c3f1c6127ee21b01aa9a169bcb55bdd758f1 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..86feae73cab10e414de5f3a2c6e400285495dfaa --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-14B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..b59ed32c5c844e40eaa053f5ccb2f9e3f5af9f0b --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..34273c1ec50a8b4850858bfc712088812c6005e6 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..f6eed97db8ab61860549704d3238db7426073e77 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..41b87e990df97bff7521e81d418868d6f1d1e054 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce6640e6334f5f81372ff435c7856864594949ad --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..afb5d3dab59cebfc4af7628f63d90b107b87e522 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000000000000000000000000000000000000..492898b46eb290647aed41b82f21ba1a60983035 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --initialize_model_on_cpu \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000000000000000000000000000000000000..1d913591cacec356f9b73b7844d51462a8d82ef4 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 720 \ + --width 1280 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --initialize_model_on_cpu \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000000000000000000000000000000000000..e0d6e842ad12407fd5cbfc4bfa3c65618b2fa94c --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,12 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-T2V-1.3B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..ae804b0503cd0903e18b32bd11a18675d01a43dc --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-14B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh new file mode 100644 index 0000000000000000000000000000000000000000..b348874f08e74f83d476444de143f4cfe4304dec --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh new file mode 100644 index 0000000000000000000000000000000000000000..763252e14ee78e74fe001157bfb136bac62577d9 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-1.3B_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..c54926347b73cc5e96fce2fd8378161ea7cc1179 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.1-VACE-14B_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab09a78de92b242e98a61fb6ba46c4c875e1062f --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \ + --data_file_keys "video,animate_pose_video,animate_face_video" \ + --height 480 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.animate_adapter." 
\ + --output_path "./models/train/Wan2.2-Animate-14B_full" \ + --trainable_models "animate_adapter" \ + --extra_inputs "input_image,animate_pose_video,animate_face_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe85ca8fa85f31707e7e114da3c6aa59eb7eebb2 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..6f5ac87c265a8cf6abc1cb42f451fe1768822679 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "control_video,reference_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..7c623a05a6679feb48e3013e118f12852538462e --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh @@ -0,0 +1,33 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image,end_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..10fb02f2923d6c324aab9c3985a980661b45832a --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh @@ -0,0 +1,37 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..10c4a5a55728f1a5bdeb3adf2101799d9da639d0 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \ + --data_file_keys "video,input_audio,s2v_pose_video" \ + --height 448 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --trainable_models "dit" \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-S2V-14B_full" \ + --extra_inputs "input_image,input_audio,s2v_pose_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..89c070429723a70e15d83acdc447a097e131abca --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh @@ -0,0 +1,33 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.417 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [875, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.417 +# boundary corresponds to timesteps [0, 875) \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000000000000000000000000000000000000..def9f897a558dc9a4f97b1d98adff11201482986 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..ecfef3219f0a2e488c6d31890d2aa4054359e2ca --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh @@ -0,0 +1,42 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 \ + --initialize_model_on_cpu +# boundary corresponds to timesteps [900, 1000] + + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_full" \ + --trainable_models "vace" \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 \ + --initialize_model_on_cpu +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/accelerate_config_14B.yaml b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3875a9da2354631046baf19e61a9d5ab1d8d6aca --- /dev/null +++ b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/wanvideo/model_training/full/krea-realtime-video.sh b/examples/wanvideo/model_training/full/krea-realtime-video.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0c4c85c191a72877c696a44df81f08bdf0eac25 --- /dev/null +++ b/examples/wanvideo/model_training/full/krea-realtime-video.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/krea-realtime-video_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/LongCat-Video.sh b/examples/wanvideo/model_training/lora/LongCat-Video.sh new file mode 100644 index 0000000000000000000000000000000000000000..022048c9736bfc6f2fce94a1136ecea9ecfa757a --- /dev/null +++ b/examples/wanvideo/model_training/lora/LongCat-Video.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LongCat-Video_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "adaLN_modulation.1,attn.qkv,attn.proj,cross_attn.q_linear,cross_attn.kv_linear,cross_attn.proj,ffn.w1,ffn.w2,ffn.w3" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh b/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..c2c609aadc626e987c6de48aab0639fee14b15d1 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh @@ -0,0 +1,18 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vap.csv \ + --data_file_keys "video,vap_video" \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 10 \ + --model_id_with_origin_paths "ByteDance/Video-As-Prompt-Wan2.1-14B:transformer/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Video-As-Prompt-Wan2.1-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vap_video,input_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-mc-lora.sh b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-mc-lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..62a713339c1cdd9c8dc7d42e4e34e2afb2c3133f --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-mc-lora.sh @@ -0,0 +1,35 @@ +export CUDA_VISIBLE_DEVICES=1,2,3,4,5,6 +export WANDB_API_KEY="eeaead80e423958f2792b06d7eab6c61796e36d8" + +LOG_DIR="/data/rczhang/PencilFolder/DiffSynth-Studio/log" +mkdir -p "$LOG_DIR" + + +nohup accelerate launch examples/wanvideo/model_training/train_mc_lora.py \ + --dataset_base_path /data/rczhang/PencilFolder/diffusion-pipe/GF_MCData \ + --dataset_metadata_path /data/rczhang/PencilFolder/diffusion-pipe/GF_MCData/metadata.csv \ + --data_file_keys "video" \ + --height 360 \ + --width 640 \ + --dataset_repeat 1 \ + --lora_checkpoint "./models/train/Wan2.1-1.3b-mc-lora/epoch-3.safetensors" \ + --model_paths '[ + "/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "/data/rczhang/PencilFolder/DiffSynth-Studio/models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" + ]' \ + --learning_rate 5e-5 \ + --num_epochs 10 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-1.3b-mc-lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 128 \ + --gradient_accumulation_steps 4 \ + --wandb_project "WanLoRA-Diffsyn" \ + --wandb_run_name "Wan2.1-1.3b-mc-lora" \ + --wandb_mode "online" \ + --wandb_log_every 10 \ + > "$LOG_DIR/Wan2.1-1.3b-mc-lora.out" 2>&1 & + +echo "训练脚本已挂后台,日志输出到 $LOG_DIR/Wan2.1-1.3b-mc-lora.out" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000000000000000000000000000000000000..51ebfe45730f74c5bcfaaaf46bc3adb476d1b563 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "motion_bucket_id" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000000000000000000000000000000000000..9a9622d55091685172ee2928e0842cc76018be4e --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..03c1f4517b7af9dbc65abe515af20f59c71067a4 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..d5f509be6956fe99291529b8eb3d737102302982 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..608df5ff2010ca89ca9ae6bec075ab658aed799f --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..37b251812c20c4012a63dde8d98b44b53b9bb1b9 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-14B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f809a477b7ba85f517c777060fa53efc977f711 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..1e7156df80d6f4a4a0dd752081c5feb6d465713c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..5879f59b57104067bb3efc77ebcbf8bf28f54df5 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..176a05fb419cc4dee753d5d282feeb6c0fa6f280 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..3ead12ce7f9a6da33c08a4c15ccf2727ae4ce3ff --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..40a8ad07950173d1b43b3595770fa4e294d44853 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000000000000000000000000000000000000..473d51981702868874720056795b96bf1fe337c7 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000000000000000000000000000000000000..52b72bdfc7c38392b15bf07a9fde58808293df13 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,19 @@ +# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA +# We tested on 8*80G GPUs +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 720 \ + --width 1280 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --initialize_model_on_cpu \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000000000000000000000000000000000000..d16a287193286ac4893878325b1a9b6076eab627 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-1.3B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..1fb55ac3467dca84420013e999b2bcc4e8d689d8 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh new file mode 100644 index 0000000000000000000000000000000000000000..2bcb55b9b6c1e4b02941a85f9fecc2f9ca470698 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh new file mode 100644 index 0000000000000000000000000000000000000000..b56507889bd957563e874c64f2c983c09db4eeb8 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..633ea0e305b102a5b388670507429bdea869e3c5 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh @@ -0,0 +1,18 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.1-VACE-14B_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..0b6e5711be2cb9e1118dbb49075a1e7821864189 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh @@ -0,0 +1,20 @@ +# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA +# We tested on 8*80G GPUs +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \ + --data_file_keys "video,animate_pose_video,animate_face_video" \ + --height 480 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Animate-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,animate_pose_video,animate_face_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh new file mode 100644 index 0000000000000000000000000000000000000000..1a9983b43166e7ae18f4fdd3b9818eda084aa632 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh @@ -0,0 +1,39 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control-Camera:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control-Camera:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh new file mode 100644 index 0000000000000000000000000000000000000000..571ae5400d573c529de0043483755bcddc15964a --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh @@ -0,0 +1,39 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-Control:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-Control:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "control_video,reference_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh new file mode 100644 index 0000000000000000000000000000000000000000..491351c9492944297c053d12a74470025aaf79a0 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh @@ -0,0 +1,37 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-Fun-A14B-InP:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-Fun-A14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-Fun-A14B-InP:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,end_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..1d9eba0aadd1f2c11d7ff6db199f1ad1e2536b59 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh @@ -0,0 +1,39 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..510796bad4eb12cbf1041d368fde7479be92e067 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh @@ -0,0 +1,18 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \ + --data_file_keys "video,input_audio,s2v_pose_video" \ + --height 448 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-S2V-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,input_audio,s2v_pose_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..f47c96b7197a99c3e1332c5a858f0a9755130ead --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh @@ -0,0 +1,38 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.417 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [875, 1000] + + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.417 +# boundary corresponds to timesteps [0, 875) \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a33b5799e71f6fc68cc00386ff28b1368e2673b --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh new file mode 100644 index 0000000000000000000000000000000000000000..93b38cfd11bed9a315dd87d2b5a9fe857ced27fe --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh @@ -0,0 +1,43 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." 
\ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_high_noise_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ + --data_file_keys "video,vace_video,vace_reference_image" \ + --height 480 \ + --width 832 \ + --num_frames 17 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.2-VACE-Fun-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,PAI/Wan2.2-VACE-Fun-A14B:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.2-VACE-Fun-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.vace." \ + --output_path "./models/train/Wan2.2-VACE-Fun-A14B_low_noise_lora" \ + --lora_base_model "vace" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "vace_video,vace_reference_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 +# boundary corresponds to timesteps [0, 900] \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/krea-realtime-video.sh b/examples/wanvideo/model_training/lora/krea-realtime-video.sh new file mode 100644 index 0000000000000000000000000000000000000000..94c64d17d4e2aacb4842b54f9064ac3440a60a7b --- /dev/null +++ b/examples/wanvideo/model_training/lora/krea-realtime-video.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "krea/krea-realtime-video:krea-realtime-video-14b.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/krea-realtime-video_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/prepare_instancev_data.py b/examples/wanvideo/model_training/prepare_instancev_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8f277fa196a4662d2d69465213a5fc002a9b56 --- /dev/null +++ b/examples/wanvideo/model_training/prepare_instancev_data.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +""" +InstanceV training data preprocessing script + +What it does: +1. Read InstanceCap.jsonl (contains Video, Global Description, Structural Description) +2. Match the mask directories in InstanceLabel +3. Match the video files in OpenVid1M-Video +4. Generate a JSONL file suitable for training + +Data directory structure: + data/InstanceCap/InstanceCap.jsonl # original annotations + data/InstanceLabel/{video_name}_masks/ # mask sequence for each instance + data/OpenVid1M-Video/{video_name}.mp4 # source videos + +Output format (one JSON object per line): +{ + "video": "path/to/video.mp4", + "prompt": "global description", + "instance_prompts": ["instance 0 description", "instance 1 description", ...], + "instance_mask_dirs": [{"mask_dir": "path/to/{video_name}_masks", "instance_id": 0, "num_frames": 49}, ...]
+} +""" + +import os +import json +import argparse +from pathlib import Path +from collections import defaultdict +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser(description="Prepare InstanceV training data") + parser.add_argument( + "--instancecap_path", + type=str, + default="/data/rczhang/PencilFolder/data/InstanceCap/InstanceCap.jsonl", + help="Path to InstanceCap.jsonl", + ) + parser.add_argument( + "--instance_label_dir", + type=str, + default="/data/rczhang/PencilFolder/data/InstanceLabel", + help="Directory containing instance masks", + ) + parser.add_argument( + "--video_dir", + type=str, + default="/data/rczhang/PencilFolder/data/OpenVid1M-Video", + help="Directory containing source videos", + ) + parser.add_argument( + "--output_path", + type=str, + default="/data/rczhang/PencilFolder/data/instancev_train.jsonl", + help="Output JSONL path", + ) + parser.add_argument( + "--min_instances", + type=int, + default=1, + help="Minimum number of instances required", + ) + parser.add_argument( + "--max_instances", + type=int, + default=10, + help="Maximum number of instances to keep", + ) + parser.add_argument( + "--use_dense_caption", + action="store_true", + help="Use dense caption format for instance prompts", + ) + return parser.parse_args() + + +def get_video_name_from_path(video_path: str) -> str: + """Extract the video name (without extension) from a video path.""" + return Path(video_path).stem + + +def find_mask_dirs(instance_label_dir: str, video_name: str) -> dict: + """ + Find all instance mask directories for a given video. + + Mask directory structure: {video_name}_masks/ + Mask file naming: {frame_id:06d}_No.{instance_id}.png + + Returns: + dict: {instance_id: [mask_file_paths]} + """ + mask_dir = os.path.join(instance_label_dir, f"{video_name}_masks") + if not os.path.isdir(mask_dir): + return {} + + instance_masks = defaultdict(list) + for fname in sorted(os.listdir(mask_dir)): + if not fname.endswith(".png"): + continue + # Parse filenames of the form 000000_No.0.png + parts = fname.replace(".png", "").split("_No.") + if len(parts) != 2: + continue + frame_id_str, inst_id_str = parts + try: + frame_id = int(frame_id_str) + inst_id = int(inst_id_str) + except ValueError: + continue + instance_masks[inst_id].append(os.path.join(mask_dir, fname)) + + return dict(instance_masks) + + +def build_instance_prompt(instance_info: dict) -> str: + """ + Build a prompt from an InstanceCap instance entry. + + instance_info structure: + { + "Class": "person", + "Appearance": "...", + "Actions and Motion": "...", + "Position": "..." + } + """ + cls = instance_info.get("Class", "object") + appearance = instance_info.get("Appearance", "") + actions = instance_info.get("Actions and Motion", "") + + # Build a concise but informative prompt + parts = [] + if cls: + parts.append(f"A {cls}") + if appearance: + parts.append(appearance) + if actions: + parts.append(actions) + + prompt = ". ".join(parts) + # Clean up redundant punctuation and collapse double spaces + prompt = prompt.replace("..", ".").replace("  ", " ").strip() + if prompt and not prompt.endswith("."): + prompt += "."
+ return prompt + + +def process_sample( + sample: dict, + instance_label_dir: str, + video_dir: str, + min_instances: int, + max_instances: int, +) -> dict | None: + """ + Process a single sample and return a dict in training format, or None if the sample is skipped. + """ + video_name = get_video_name_from_path(sample.get("Video", "")) + if not video_name: + return None + + # Check that the source video exists + video_path = os.path.join(video_dir, f"{video_name}.mp4") + if not os.path.isfile(video_path): + return None + + # Look up the mask directories + instance_masks = find_mask_dirs(instance_label_dir, video_name) + if len(instance_masks) < min_instances: + return None + + # Extract instance information + struct_desc = sample.get("Structural Description", {}) + main_instances = struct_desc.get("Main Instance", {}) + + instance_prompts = [] + instance_mask_dirs = [] + + for inst_key in sorted(main_instances.keys()): + # inst_key: "No.0", "No.1", ... + try: + inst_id = int(inst_key.replace("No.", "")) + except ValueError: + continue + + if inst_id not in instance_masks: + continue + + inst_info = main_instances[inst_key] + prompt = build_instance_prompt(inst_info) + if not prompt: + continue + + instance_prompts.append(prompt) + # Store the whole mask directory path (not a per-file list; masks are loaded on the fly during training) + mask_dir = os.path.dirname(instance_masks[inst_id][0]) + instance_mask_dirs.append({ + "mask_dir": mask_dir, + "instance_id": inst_id, + "num_frames": len(instance_masks[inst_id]), + }) + + if len(instance_prompts) >= max_instances: + break + + if len(instance_prompts) < min_instances: + return None + + # Build the output entry + global_desc = sample.get("Global Description", "") + background = struct_desc.get("Background Detail", "") + camera = struct_desc.get("Camera Movement", "") + + # Merge into a single full prompt + full_prompt_parts = [global_desc] + if background: + full_prompt_parts.append(background) + if camera: + full_prompt_parts.append(camera) + full_prompt = " ".join(full_prompt_parts) + + return { + "video": video_path, + "prompt": full_prompt, + "instance_prompts": instance_prompts, + "instance_mask_dirs": instance_mask_dirs, + } + + +def main(): + args = parse_args() + + # Load InstanceCap + print(f"Loading InstanceCap from {args.instancecap_path}") + samples = [] + with open(args.instancecap_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + samples.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse line: {e}") + continue + + print(f"Loaded {len(samples)} samples") + + # Process each sample + output_samples = [] + for sample in tqdm(samples, desc="Processing samples"): + result = process_sample( + sample, + args.instance_label_dir, + args.video_dir, + args.min_instances, + args.max_instances, + ) + if result is not None: + output_samples.append(result) + + print(f"Valid samples: {len(output_samples)} / {len(samples)}") + + # Write the output + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path, "w", encoding="utf-8") as f: + for sample in output_samples: + f.write(json.dumps(sample, ensure_ascii=False) + "\n") + + print(f"Saved to {args.output_path}") + + # Print statistics + if output_samples: + avg_instances = sum(len(s["instance_prompts"]) for s in output_samples) / len(output_samples) + print(f"Average instances per sample: {avg_instances:.2f}") + + +if __name__ == "__main__": + main() + diff --git a/examples/wanvideo/model_training/prepare_instancev_iground.py b/examples/wanvideo/model_training/prepare_instancev_iground.py new file mode 100644 index 0000000000000000000000000000000000000000..f3448e6124cbf79c3b9779e7ffcf05dfd39f5ab9 --- /dev/null +++ 
b/examples/wanvideo/model_training/prepare_instancev_iground.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Prepare InstanceV training data from iGround processed JSONL. + +Outputs per-line JSON with: +{ + "video": "relative/path/to/clip.mp4", + "prompt": "caption", + "instance_prompts": ["phrase1", "phrase2", ...], + "instance_mask_dirs": [ + {"mask_dir": "/abs/path/to/masks", "instance_id": 0, "num_frames": 49}, + ... + ] +} +""" + +import argparse +import json +import math +import os +from pathlib import Path + +import imageio.v2 as imageio +from PIL import Image, ImageDraw +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser(description="Prepare InstanceV data from iGround") + parser.add_argument( + "--iground_jsonl", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/iGround_train_set_processed.jsonl", + help="Path to iGround processed JSONL.", + ) + parser.add_argument( + "--clips_dir", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/Clips/train", + help="Directory containing iGround clips.", + ) + parser.add_argument( + "--mask_root_dir", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/InstanceMasks/train", + help="Root directory to store generated instance masks.", + ) + parser.add_argument( + "--output_metadata", + type=str, + default="/data/rczhang/PencilFolder/data/iGround/instancev_iground_train.jsonl", + help="Output metadata JSONL path.", + ) + parser.add_argument( + "--dataset_base_path", + type=str, + default="/data/rczhang/PencilFolder/data", + help="Base path used by UnifiedDataset (video paths will be relative to this).", + ) + parser.add_argument( + "--min_instances", + type=int, + default=1, + help="Minimum number of instances required.", + ) + parser.add_argument( + "--max_instances", + type=int, + default=None, + help="Maximum number of instances to keep (None = keep all).", + ) + parser.add_argument( + "--overwrite_masks", + action="store_true", + help="Overwrite existing masks for a clip.", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of samples for debugging.", + ) + return parser.parse_args() + + +def _safe_relpath(path: str, base_path: str) -> str: + if not base_path: + return path + return os.path.relpath(path, base_path) + + +def _clamp_bbox(bbox, width: int, height: int): + if not bbox or len(bbox) != 4: + return None + x0, y0, x1, y1 = bbox + left = max(0, int(math.floor(x0))) + top = max(0, int(math.floor(y0))) + right = min(width, int(math.ceil(x1))) + bottom = min(height, int(math.ceil(y1))) + if right <= left or bottom <= top: + return None + return left, top, right, bottom + + +def _collect_visible_phrases(phrases, labels_per_frame): + visible = set() + for labels in labels_per_frame: + for label in labels: + visible.add(label) + return [p for p in phrases if p in visible] + + +def _write_masks( + mask_dir: str, + phrases, + labels_per_frame, + bboxes_per_frame, + width: int, + height: int, + overwrite: bool, +): + if os.path.isdir(mask_dir) and not overwrite: + return + os.makedirs(mask_dir, exist_ok=True) + + phrase_set = set(phrases) + num_frames = len(bboxes_per_frame) + for frame_idx in range(num_frames): + labels = labels_per_frame[frame_idx] + bboxes = bboxes_per_frame[frame_idx] + frame_map = {} + for label, bbox in zip(labels, bboxes): + if label in phrase_set: + frame_map[label] = bbox + + for inst_id, phrase in enumerate(phrases): + mask = Image.new("L", (width, height), 0) + bbox = frame_map.get(phrase) + if bbox is not None: 
+ coords = _clamp_bbox(bbox, width, height) + if coords is not None: + draw = ImageDraw.Draw(mask) + draw.rectangle(coords, fill=255) + mask_path = os.path.join(mask_dir, f"{frame_idx:06d}_No.{inst_id}.png") + mask.save(mask_path) + + +def _is_video_readable(video_path: str) -> bool: + try: + reader = imageio.get_reader(video_path) + try: + reader.get_data(0) + finally: + reader.close() + except Exception: + return False + return True + + +def main(): + args = parse_args() + + Path(args.mask_root_dir).mkdir(parents=True, exist_ok=True) + Path(os.path.dirname(args.output_metadata)).mkdir(parents=True, exist_ok=True) + + processed = 0 + skipped_missing_video = 0 + skipped_instances = 0 + skipped_unreadable = 0 + wrote = 0 + + with open(args.iground_jsonl, "r", encoding="utf-8") as f_in, open( + args.output_metadata, "w", encoding="utf-8" + ) as f_out: + for line in tqdm(f_in, desc="Processing iGround"): + if args.limit is not None and wrote >= args.limit: + break + line = line.strip() + if not line: + continue + processed += 1 + sample = json.loads(line) + + video_id = sample["video_id"] + clip_id = sample["clip_id"] + clip_name = f"{video_id}_{clip_id}.mp4" + clip_path = os.path.join(args.clips_dir, clip_name) + if not os.path.isfile(clip_path): + skipped_missing_video += 1 + continue + if not _is_video_readable(clip_path): + skipped_unreadable += 1 + continue + + phrases = list(sample.get("phrases", [])) + labels_per_frame = sample.get("labels", []) + bboxes_per_frame = sample.get("bboxes", []) + if not phrases or not labels_per_frame or not bboxes_per_frame: + skipped_instances += 1 + continue + + visible_phrases = _collect_visible_phrases(phrases, labels_per_frame) + if args.max_instances is not None: + visible_phrases = visible_phrases[: args.max_instances] + + if len(visible_phrases) < args.min_instances: + skipped_instances += 1 + continue + + width = int(sample["width"]) + height = int(sample["height"]) + + mask_dir = os.path.join(args.mask_root_dir, f"{video_id}_{clip_id}_masks") + _write_masks( + mask_dir, + visible_phrases, + labels_per_frame, + bboxes_per_frame, + width, + height, + overwrite=args.overwrite_masks, + ) + + instance_mask_dirs = [ + { + "mask_dir": mask_dir, + "instance_id": inst_id, + "num_frames": len(bboxes_per_frame), + } + for inst_id in range(len(visible_phrases)) + ] + + entry = { + "video": _safe_relpath(clip_path, args.dataset_base_path), + "prompt": sample.get("caption", ""), + "instance_prompts": visible_phrases, + "instance_mask_dirs": instance_mask_dirs, + } + f_out.write(json.dumps(entry, ensure_ascii=False) + "\n") + wrote += 1 + + print("Done.") + print(f"Processed: {processed}") + print(f"Wrote: {wrote}") + print(f"Skipped (missing video): {skipped_missing_video}") + print(f"Skipped (unreadable video): {skipped_unreadable}") + print(f"Skipped (insufficient instances): {skipped_instances}") + + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/model_training/prepare_instancev_instancecap_bbox.py b/examples/wanvideo/model_training/prepare_instancev_instancecap_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..40076aa55032f4e581ecc23969d48491e7af0a35 --- /dev/null +++ b/examples/wanvideo/model_training/prepare_instancev_instancecap_bbox.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +""" +Prepare InstanceV training data from InstanceCap + InstanceCap-BBox. + +Outputs per-line JSON: +{ + "video": "OpenVid1M-Video-InstanceCap/