Scripts to train a Latent Diffusion Model based on Pinaya et al. "Brain imaging generation with latent diffusion models" on the MIMIC-CXR dataset using the MONAI Generative Models package.
After downloading the JPG images from MIMIC-CXR-JPG and the associated free-text reports from the MIMIC-CXR Database, you need to preprocess the data. The preprocessing scripts, in order of execution, are:
- src/python/preprocessing/organise.py - Resizes the dataset images to 512 pixels in the smaller dimension.
- src/python/preprocessing/create_ids.py - Creates files with the datalists for training, validation, and test using only "PA" views.
- src/python/preprocessing/create_section_files.py - Creates a file with the text sections of each report.
- src/python/preprocessing/create_sentences_files.py - Creates a file with the sentences of each report.
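The resizing rule described for organise.py (smaller dimension scaled to 512 pixels) can be illustrated as a size calculation. This is a minimal sketch, not the repository's code: the function name `target_size` is hypothetical, and it assumes that the aspect ratio is preserved and that the larger side is rounded to the nearest pixel, neither of which is stated above.

```python
from __future__ import annotations


def target_size(width: int, height: int, smaller_side: int = 512) -> tuple[int, int]:
    """Return (width, height) after scaling so the smaller side equals smaller_side.

    Assumption (not confirmed by the README): the aspect ratio is preserved
    and the larger side is rounded to the nearest whole pixel.
    """
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    if width <= height:
        # Width is the smaller side: pin it and scale the height.
        return smaller_side, round(height * smaller_side / width)
    # Height is the smaller side: pin it and scale the width.
    return round(width * smaller_side / height), smaller_side


if __name__ == "__main__":
    # Example with a portrait image size chosen purely for illustration.
    print(target_size(2544, 3056))
```

The resulting tuple could then be passed to whichever image library does the actual resampling.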
After preprocessing, you can train the model using commands similar to those in the following files (note: this project was run on a cluster with the RunAI platform):
- cluster/runai/training/stage1.sh - Command to start training the first stage of the model on the server. The main Python script for this is src/python/training/train_aekl.py. The --volume flags indicate how the dataset is mounted in the Docker container.
- src/python/training/eda_ldm_scaling_factor.py - Script to find the best scaling factor for the latent diffusion model.
- cluster/runai/training/ldm.sh - Command to start training the diffusion model on the latent representation on the server. The main Python script for this is src/python/training/train_ldm.py. The --volume flags indicate how the dataset is mounted in the Docker container.
These .sh files indicate which parameters and configuration files were used for training, as well as how the host directories were mounted in the Docker container.
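To illustrate what a scaling factor for the latent space might look like, a common convention in latent diffusion work is to multiply the first-stage latents by the reciprocal of their standard deviation, so that the diffusion model sees values with roughly unit variance. Whether eda_ldm_scaling_factor.py uses exactly this rule is an assumption; the function name `estimate_scaling_factor` is hypothetical, and the latents are represented here as a flat list of floats rather than tensors.

```python
import statistics
from typing import Iterable


def estimate_scaling_factor(latent_values: Iterable[float]) -> float:
    """Return 1 / std of the given latent values.

    Assumption: the "best" scaling factor is the one that gives the scaled
    latents unit standard deviation. The repository script may use another rule.
    """
    values = list(latent_values)
    std = statistics.pstdev(values)  # population standard deviation
    if std == 0:
        raise ValueError("latent values have zero variance")
    return 1.0 / std


if __name__ == "__main__":
    # Toy latent values, for illustration only.
    latents = [-2.0, 2.0]
    factor = estimate_scaling_factor(latents)
    scaled = [v * factor for v in latents]
    print(factor, statistics.pstdev(scaled))
```

Under this convention, latents are multiplied by the factor before the diffusion model is trained and divided by it again before decoding samples with the first stage.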
Finally, we converted the mlflow model to .pth files (for easy loading in MONAI), sampled images from the diffusion model, and evaluated the model. The scripts for inference and evaluation, in order of execution, are:
- src/python/testing/convert_mlflow_to_pytorch.py - Converts the mlflow model to .pth files.
- src/python/testing/sample_images.py - Samples images from the diffusion model. cluster/runai/testing/sampling_unconditioned.sh shows how to execute this script on the server to generate the 1000 samples used in the following scripts.
- src/python/testing/compute_msssim_reconstruction.py - Measures the mean structural similarity index between images and their reconstructions to assess the performance of the first stage.
- src/python/testing/compute_msssim_sample.py - Measures the mean structural similarity index between test images and samples to assess the diversity of the synthetic data.
- src/python/testing/compute_msssim_test_set.py - Measures the mean structural similarity index between test images to assess the diversity of the reference test set.
- src/python/testing/compute_fid.py - Computes the FID score between generated images and real images.
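The diversity measurements above rest on averaging a similarity score over pairs of images, where a lower mean similarity indicates a more diverse set. The sketch below shows only that averaging step. It is not the repository's code: `mean_pairwise_score` and `toy_similarity` are hypothetical names, the toy score on plain floats is a stand-in for a real structural similarity metric on images, and the choice of all unordered pairs within one set is an assumption about how pairs are selected.

```python
from itertools import combinations
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def mean_pairwise_score(items: Sequence[T], score: Callable[[T, T], float]) -> float:
    """Average score(a, b) over all unordered pairs of distinct items."""
    pairs = list(combinations(items, 2))
    if not pairs:
        raise ValueError("need at least two items to form a pair")
    return sum(score(a, b) for a, b in pairs) / len(pairs)


def toy_similarity(a: float, b: float) -> float:
    """Placeholder similarity on numbers in [0, 1]; NOT a structural similarity metric."""
    return 1.0 - abs(a - b)


if __name__ == "__main__":
    # Three toy "images" represented by single numbers, for illustration only.
    print(mean_pairwise_score([0.0, 0.5, 1.0], toy_similarity))
```

In practice the `score` argument would be replaced by a function that computes the structural similarity between two image tensors.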
- Version 0.1 - (Mar 9, 2023) Initial release
- Version 0.2 - (Apr 9, 2023) Model with flipped images fixed.