- [2024/10] 🔥 We open source the inference code and Gradio demo for HART!
We introduce Hybrid Autoregressive Transformer (HART), an autoregressive (AR) visual generation model capable of directly generating 1024x1024 images, rivaling diffusion models in image generation quality. Existing AR models face limitations due to the poor image reconstruction quality of their discrete tokenizers and the prohibitive training costs associated with generating 1024px images. To address these challenges, we present the hybrid tokenizer, which decomposes the continuous latents from the autoencoder into two components: discrete tokens representing the big picture and continuous tokens representing the residual components that cannot be represented by the discrete tokens. The discrete component is modeled by a scalable-resolution discrete AR model, while the continuous component is learned with a lightweight residual diffusion module with only 37M parameters. Compared with the discrete-only VAR tokenizer, our hybrid approach improves reconstruction FID from 2.11 to 0.30 on MJHQ-30K, leading to a 31% generation FID improvement from 7.85 to 5.38. HART also outperforms state-of-the-art diffusion models in both FID and CLIP score, with 4.5-7.7x higher throughput and 6.9-13.4x lower MACs.
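The decomposition idea can be illustrated with a toy sketch. Each continuous latent vector is snapped to its nearest codebook entry, which gives the discrete token. Whatever that entry misses is kept as a continuous residual, so the discrete part plus the residual reproduces the latent exactly. This is a simplification and not HART's tokenizer. The single-scale nearest-neighbour quantizer, the function names (`hybrid_tokenize`, `reconstruct`), and the toy codebook are all assumptions made for illustration. In HART the residual is generated by the residual diffusion module rather than stored.

```python
# Toy sketch of the hybrid-tokenizer idea: discrete token + continuous residual.
# ASSUMPTIONS (hypothetical, not HART's implementation): a single-scale
# nearest-neighbour quantizer over a tiny hand-written codebook. HART's real
# tokenizer is multi-scale and learned, and its residual is generated by a
# lightweight diffusion module instead of being stored.
from typing import List, Sequence, Tuple

Vector = List[float]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def hybrid_tokenize(latent: Sequence[float],
                    codebook: Sequence[Sequence[float]]) -> Tuple[int, Vector]:
    """Split one latent vector into (discrete token index, continuous residual)."""
    # Discrete "big picture": index of the closest codebook entry.
    token = min(range(len(codebook)),
                key=lambda k: squared_distance(latent, codebook[k]))
    # Continuous residual: what the chosen codebook entry cannot represent.
    residual = [x - c for x, c in zip(latent, codebook[token])]
    return token, residual


def reconstruct(token: int, residual: Sequence[float],
                codebook: Sequence[Sequence[float]]) -> Vector:
    """Adding the residual back onto the codebook entry recovers the latent."""
    return [c + r for c, r in zip(codebook[token], residual)]


if __name__ == "__main__":
    codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]  # toy values
    latent = [0.9, 1.2]
    token, residual = hybrid_tokenize(latent, codebook)
    print("token:", token, "residual:", residual)
    print("reconstruction:", reconstruct(token, residual, codebook))
```

A discrete-only reconstruction would return just `codebook[token]`, and its error is exactly the residual. That gap is what the continuous tokens are meant to close.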
Download the repo:

```bash
git clone https://github.com/mit-han-lab/hart
cd hart
```

Set up the environment, install the package, and build the kernels in `hart/kernels`:

```bash
conda create -n hart python=3.10
conda activate hart
conda install -c nvidia cuda-toolkit -y
pip install -e .
cd hart/kernels && python setup.py install
```
Download Qwen2-VL-1.5B-Instruct:

```bash
git clone https://huggingface.co/mit-han-lab/Qwen2-VL-1.5B-Instruct
```
Download the HART tokenizer and models:

```bash
git clone https://huggingface.co/mit-han-lab/hart-0.7b-1024px
```
Download the safety check model:

```bash
git clone https://huggingface.co/google/shieldgemma-2b
```
Note: We use ShieldGemma-2B from Google DeepMind to filter out unsafe prompts in our demo. We strongly recommend using it if you distribute our demo publicly.

You can launch the Gradio demo with the following command:

```bash
python app.py --model_path /path/to/model \
  --text_model_path /path/to/Qwen2 \
  --shield_model_path /path/to/ShieldGemma2B
```
Note: `--model_path` must point to the `llm` folder inside our pretrained checkpoint. For example, if your model is stored at `checkpoints/hart-0.7b-1024px`, then `--model_path` should be `checkpoints/hart-0.7b-1024px/llm`. The same applies to all commands below.
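If you launch the demo from a Python script, it is easy to forget the `llm` subfolder. The sketch below derives `--model_path` from the checkpoint root and assembles the `app.py` command as an argument list. The helper names (`resolve_model_path`, `build_demo_command`) and the example paths are hypothetical. Only the flags and the `llm` subfolder come from this README.

```python
# Sketch: build the Gradio-demo command from a checkpoint root.
# ASSUMPTION: helper names and example paths are hypothetical; only the flags
# (--model_path, --text_model_path, --shield_model_path) and the "llm"
# subfolder come from the README.
import shlex
from pathlib import Path
from typing import List


def resolve_model_path(checkpoint_dir: str) -> str:
    """Map e.g. checkpoints/hart-0.7b-1024px -> checkpoints/hart-0.7b-1024px/llm."""
    return str(Path(checkpoint_dir) / "llm")


def build_demo_command(checkpoint_dir: str, text_model_path: str,
                       shield_model_path: str) -> List[str]:
    """Return the app.py invocation as an argv list (no shell quoting needed)."""
    return [
        "python", "app.py",
        "--model_path", resolve_model_path(checkpoint_dir),
        "--text_model_path", text_model_path,
        "--shield_model_path", shield_model_path,
    ]


if __name__ == "__main__":
    cmd = build_demo_command("checkpoints/hart-0.7b-1024px",
                             "/path/to/Qwen2", "/path/to/ShieldGemma2B")
    print(shlex.join(cmd))  # printable shell form of the same command
```

To actually start the demo, run the list with `subprocess.run(cmd, check=True)` from the repository root.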
- Sampling with a single prompt:

```bash
python sample.py --model_path /path/to/model \
  --text_model_path /path/to/Qwen2 \
  --prompt "YOUR_PROMPT" \
  --sample_folder_dir /path/to/save_dir \
  --shield_model_path /path/to/ShieldGemma2B
```
- Sampling with multiple prompts:

```bash
# Add --store_separately to save each image individually; otherwise, all images are stored in a single grid.
python sample.py --model_path /path/to/model \
  --text_model_path /path/to/Qwen2 \
  --prompt_list [Prompt1, Prompt2, ..., PromptN] \
  --sample_folder_dir /path/to/save_dir \
  --shield_model_path /path/to/ShieldGemma2B
```
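As an alternative to `--prompt_list`, you can drive `sample.py` once per prompt with `--prompt`. Passing each prompt as a single argv element avoids shell quoting issues with spaces or brackets. This is a minimal sketch and not an official script. The helper names and the per-prompt sub-folder layout (`prompt_000`, `prompt_001`, ...) are hypothetical. Only the flags come from this README. Each run reloads the models, so this is slower than a single `--prompt_list` call.

```python
# Sketch: one sample.py run per prompt, each writing into its own sub-folder.
# ASSUMPTIONS: helper names and the prompt_NNN folder layout are hypothetical;
# only the command-line flags come from the README. Models are reloaded on
# every run, so this trades speed for simple, quoting-free invocation.
import subprocess
from pathlib import Path
from typing import List, Sequence


def build_sample_commands(prompts: Sequence[str], model_path: str,
                          text_model_path: str, shield_model_path: str,
                          save_dir: str) -> List[List[str]]:
    """Return one sample.py argv list per prompt."""
    commands = []
    for i, prompt in enumerate(prompts):
        commands.append([
            "python", "sample.py",
            "--model_path", model_path,
            "--text_model_path", text_model_path,
            "--prompt", prompt,  # a single argv element: no shell quoting needed
            "--sample_folder_dir", str(Path(save_dir) / f"prompt_{i:03d}"),
            "--shield_model_path", shield_model_path,
        ])
    return commands


def run_all(commands: Sequence[Sequence[str]]) -> None:
    """Run the commands sequentially, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

For example, `run_all(build_sample_commands(["a red fox in snow", "a lighthouse at dusk"], "/path/to/model", "/path/to/Qwen2", "/path/to/ShieldGemma2B", "/path/to/save_dir"))` would run from the repository root.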
Use the following command to benchmark latency:

```bash
python latency_profile.py --model_path /path/to/model \
  --text_model_path /path/to/Qwen2
```
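The script above profiles HART itself. To time your own calls, such as a single generation inside a script, a common pattern is to run a few warm-up iterations and then report the median over several timed runs. The helper below is a generic sketch, not the logic of `latency_profile.py`, and its name and defaults are hypothetical. CUDA kernels launch asynchronously, so pass a synchronization callback such as `torch.cuda.synchronize` when timing GPU work. Without it, the clock can stop before the GPU has finished.

```python
# Generic latency-measurement sketch (hypothetical helper, not latency_profile.py).
import statistics
import time
from typing import Callable, Dict, Optional


def measure_latency(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
                    sync: Optional[Callable[[], None]] = None) -> Dict[str, float]:
    """Return median and mean wall-clock latency of fn() in seconds."""
    # Warm-up runs absorb one-off costs (caching, lazy init, kernel autotuning).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work is counted
        samples.append(time.perf_counter() - start)
    return {"median_s": statistics.median(samples),
            "mean_s": statistics.fmean(samples)}
```

The median is reported alongside the mean because it is less sensitive to occasional slow iterations.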
Our codebase is inspired by amazing open source research projects such as VAR and MAR. The authors would like to thank Tianhong Li from MIT, Lijun Yu from Google DeepMind, Kaiwen Zha from MIT and Yunhao Fang from UCSD for helpful discussions; and Paul Palei, Mike Hobbs, Chris Hill, Michel Erb from MIT for setting up the online demo and maintaining the server.
```bibtex
@article{tang2024hart,
  title={HART: Efficient Visual Generation with Hybrid Autoregressive Transformer},
  author={Tang, Haotian and Wu, Yecheng and Yang, Shang and Xie, Enze and Chen, Junsong and Chen, Junyu and Zhang, Zhuoyang and Cai, Han and Lu, Yao and Han, Song},
  journal={arXiv preprint},
  year={2024}
}
```