Add Stable Diffusion 3 Example #2558
Conversation
Commits:
- Add get_qkv_linear to handle different dimensionality in linears
- Add stable diffusion 3 example
- Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
- adapt existing AutoEncoderKLConfig to the change
- add forward_until_encoder_layer to ClipTextTransformer
- rename sd3 config to sd3_medium in mmdit; minor clean-up
- Enable flash-attn for mmdit impl when the feature is enabled.
- Add sd3 example codebase
- add document crediting references
- pass the cargo fmt test
- pass the clippy test
Looks pretty good, thanks for adding this. Would you mind replacing the sample image with a jpg version? (the png version you attached takes almost 1MB which is not great for the repo size)
@LaurentMazare Thank you for reminding me of this. The sample image has been replaced with a JPG. The original PNG should be excluded from the Git objects after squash-merging.
Merged, thanks a lot!
It's so great to have this, thanks for the work. Stable Diffusion 3.5 Large is out now and it looks amazing, as it's a full base model that we can train on. Looks like it's not working with candle yet though.
@super-fun-surf Working on it. SD3.5 changes the MMDiT architecture a little bit (namely MMDiT-X). That needs to be done first before implementing a working example.
@LaurentMazare That's awesome! I'll take a look.
It appears the focus of the community has largely shifted to Flux.dev1. So the main purpose of this PR is to demonstrate the capability of Candle and to serve as a smoke test for the MMDiT implementation (#2397).
As such, I intend to minimize intrusive changes to the existing stable-diffusion codebase, for example by using a renaming function to adapt the VAE var-builder to the official safetensors weights of the SD3 VAE. Still, there are some changes I had to make to `candle_nn::stable_diffusion` to support the CLIP and VAE of SD3, including:

- Add `forward_until_encoder_layer` to `ClipTextTransformer`. The Comfy implementation for SD3 uses the penultimate hidden layer of CLIP-L and CLIP-G instead of the final layer (see sd3_clip.py and sdxl_clip.py). This practice, although not mentioned in the SD3 tech report, is referred to and specified in Chapter 2.1 of the SDXL tech report.
- Add `use_quant_conv` and `use_post_quant_conv` options to `AutoEncoderKL`, as SD3's VAE does not have those layers (a config sketch follows below). These changes might be considered not specific to SD3, as `diffusers` already supports these options.
- Add `get_qkv_linear` to load the attention block in `candle_nn::stable_diffusion::attention`, as some linear-layer weights of the VAE in the official SD3 Medium safetensors follow the dimension convention `(channel, channel, 1, 1)` instead of the regular `(channel, channel)` that is naturally supported by the `nn::linear` constructor (a sketch of this follows below).

These changes allow reusing the existing CLIP and VAE implementations, but inevitably add complexity to the existing codebase. @LaurentMazare Let me know if these intrusive changes are justified. We may consider alternatives like re-implementing the VAE and CLIP from scratch.
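For illustration, here is roughly what the VAE config change looks like. The fields other than the two new booleans mirror the existing `AutoEncoderKLConfig`; the default values shown are illustrative, not necessarily the exact ones in the PR.

```rust
// Sketch of the extended VAE config: the two new flags let AutoEncoderKL skip the
// quant/post-quant convolutions, which the SD3 VAE does not have. Defaults below
// are illustrative.
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
    pub block_out_channels: Vec<usize>,
    pub layers_per_block: usize,
    pub latent_channels: usize,
    pub norm_num_groups: usize,
    pub use_quant_conv: bool,
    pub use_post_quant_conv: bool,
}

impl Default for AutoEncoderKLConfig {
    fn default() -> Self {
        Self {
            block_out_channels: vec![64],
            layers_per_block: 1,
            latent_channels: 4,
            norm_num_groups: 32,
            // Existing SD 1.x/2.x/XL VAEs keep these convolutions.
            use_quant_conv: true,
            use_post_quant_conv: true,
        }
    }
}
```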
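And a minimal sketch of the `get_qkv_linear` idea (the exact signature and error handling in the PR may differ): try the conv-style `(channel, channel, 1, 1)` layout first and reshape it down to the 2D layout a linear layer expects.

```rust
use candle::Result;
use candle_nn::{Linear, VarBuilder};

// Illustrative helper: some q/k/v/proj weights of the SD3 Medium VAE are stored as
// (channel, channel, 1, 1) conv-style tensors, so we squeeze the trailing singleton
// dims before constructing the Linear; otherwise fall back to the plain 2D layout.
fn get_qkv_linear(channels: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = match vb.get((channels, channels, 1, 1), "weight") {
        Ok(w) => w.reshape((channels, channels))?,
        Err(_) => vb.get((channels, channels), "weight")?,
    };
    let bias = vb.get(channels, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}
```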
On top of these changes, I added support for flash attention in the MMDiT, gated on whether the `flash-attn` feature is enabled (a sketch of the feature gate is included below). I also ran a simple performance benchmark on GPUs such as the 3090 Ti and 4090.

A side note: the T5 implementation on the current main branch does not support FP16 yet. I attempted to insert simple clampings within the FP16 dynamic range, but it didn't work well on my GPUs. Looks like I need to wait for a more sophisticated implementation such as #2481. So for now, I use two different VarBuilders: one maps the safetensors weights to FP32 specifically for T5, the other is used for the rest of the components (a sketch of this workaround follows below).
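Roughly, the feature gate looks like the sketch below; the actual MMDiT attention code is more involved, and dtype/layout handling is omitted here.

```rust
use candle::Tensor;
use candle_nn::ops::softmax_last_dim;

// With the flash-attn feature enabled, dispatch to the fused kernel; otherwise fall
// back to vanilla scaled dot-product attention. Note that candle_flash_attn expects
// (batch, seq, heads, head_dim) tensors in f16/bf16, while the naive path below uses
// (batch, heads, seq, head_dim); the caller is assumed to handle the transposes.
#[cfg(feature = "flash-attn")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> candle::Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, scale, /* causal */ false)
}

#[cfg(not(feature = "flash-attn"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> candle::Result<Tensor> {
    let attn = (q.matmul(&k.t()?)? * scale as f64)?;
    softmax_last_dim(&attn)?.matmul(v)
}
```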
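And a sketch of the two-VarBuilder workaround. The filename is a placeholder, and using F16 for the non-T5 components is just an assumed default dtype for the rest of the pipeline.

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    // Placeholder path; in practice the file would come from a hub download.
    let files = vec!["sd3_medium_incl_clips_t5xxlfp16.safetensors".to_string()];
    // One builder maps the weights to F32 specifically for T5, which overflows in F16.
    let vb_t5_fp32 = unsafe { VarBuilder::from_mmaped_safetensors(&files, DType::F32, &device)? };
    // The other builder (here F16) is used for the MMDiT, the CLIP encoders and the VAE.
    let vb_fp16 = unsafe { VarBuilder::from_mmaped_safetensors(&files, DType::F16, &device)? };
    // ... build the T5 encoder from vb_t5_fp32 and the rest of the pipeline from vb_fp16 ...
    let _ = (vb_t5_fp32, vb_fp16);
    Ok(())
}
```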