Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sd3.5 medium and MMDiT-X #2587

Merged
merged 5 commits into from
Oct 30, 2024

Conversation

Czxck001
Copy link
Contributor

@Czxck001 Czxck001 commented Oct 30, 2024

Stable Diffusion 3.5 Medium has been released on Oct 29, with modified archetecture named MMDiT-X. This PR adds support to the Stable Diffusion 3.5 Medium and MMDiT-X model.

Change is based on reference design sd3.5/mmdit-x.py in comparison with sd3-ref/mmdit.py, including

  • an extra self-attention for the MMDiT-X block is placed here.
  • modified pre_attention for x_block (code here)
  • modified post-attention for x_block (code here)
  • and different block-joining (between x and context) is present here.

Note: A change has been made in sd3.5 after the release of Stable Diffusion 3.5 Medium on Oct 29 that fixes some bugs in the original reference design.

Implementation-wise,

  • a trait polymorphism is kept between the old and new JointBlock, but individual DiTBlock is re-implemented to avoid coupling. Ad-hoc adaptation to original DiTBlock has been attempted and dropped as it seems less sensible in terms of software engineering.
  • SD3.5 has the X-block in the first 12 layers out of total 24 layers of JointBlock changed to an extra attention attn2 side track (namely "Self Attention"). None of the context-blocks have this extra attention. So the MMDiTXJointBlock is set to use this specification without further generalization.

References:

Sample image generated with Stable Diffusion 3.5 Medium:
out

depth: 24,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 384,
Copy link
Contributor Author

@Czxck001 Czxck001 Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notably, pos_embed_max_size for position embedding has been increased for SD3.5-medium, making it even bigger than SD3.5-large, which kept the original size of that of SD3-medium...

@LaurentMazare
Copy link
Collaborator

Amazing, thanks!

@LaurentMazare LaurentMazare merged commit d232e13 into huggingface:main Oct 30, 2024
10 checks passed
@Czxck001 Czxck001 deleted the support-sd3.5-medium branch October 30, 2024 05:19
super-fun-surf pushed a commit to aifx-art/candle that referenced this pull request Oct 30, 2024
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants