-
Hi there, I am trying to adapt the 3D LDM tutorial to our data using the pytorch-lightning framework and I am struggling to get it running properly. My problem is that while the training / validation loss curves are looking very fine, the sampled synthetic images I am logging to tensorboard each 10 epochs do show no progress at all. I have no clue, if I have made so mistake in transposing the code to Any help is very appreciated!! Currently, the implementation is as follows: # Tutorials:
# https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/3d_ldm/3d_ldm_tutorial.py
# https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/08-deep-autoencoders.html
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
from monai.visualize import plot_2d_or_3d_image
from torch.nn import L1Loss
from torch.cuda.amp import GradScaler, autocast
from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler
import pandas as pd
import logging
from backend.datamodule import DataModule
def _compute_scale_factor(metadata: pd.DataFrame, dataset_dir: str, ae_model, batch_size: int = 8, use_persistent_cache: bool = False, device: str = "cuda:0"):
# define scaling factor (https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/3d_ldm/3d_ldm_tutorial.py#L314)
dm = DataModule(
metadata=metadata,
dataset_dir=dataset_dir,
batch_size=batch_size,
use_persistent_cache=use_persistent_cache
)
dm.setup()
batch_sizes = []
scale_factors = []
# generate iterator
for batch in dm.train_dataloader():
batch_sizes.append(len(batch["image"]))
with torch.no_grad():
with autocast(enabled=True):
z = ae_model.autoencoder.encode_stage_2_inputs(batch["image"].to(device))
scale_factors.append(1/torch.std(z).item())
# calc mean
mean_scale_factor = np.sum(np.array(scale_factors) * np.array(batch_sizes)) / np.sum(np.array(batch_sizes))
msg = f"Scaling factor set to {mean_scale_factor}"
print(msg)
logging.info(msg)
return mean_scale_factor
class LDMUNet(L.LightningModule):
def __init__(self, **kwargs):
super().__init__()
# log hyperparameters
self.save_hyperparameters(ignore=['ae_model'])
self.automatic_optimization = True
self.training_step_outputs = []
self.validation_step_outputs = []
self.test_step_outputs = []
self.autoencoder = kwargs["ae_model"].autoencoder
self.autoencoder.requires_grad_(False)
self.z_example = None
# only log images from first batch of every tenth val-epoch
self.logged_valid_imgs = False
# finally, define network
self.unet = DiffusionModelUNet(
spatial_dims=3,
in_channels=3,
out_channels=3,
num_res_blocks=2,
num_channels=(64, 128, 256, 512),
attention_levels=(False, False, True, True),
norm_num_groups=32,
norm_eps=1e-6,
resblock_updown=False,
num_head_channels=8,
#num_head_channels=(0, 0, 512, 512),
with_conditioning=False,
transformer_num_layers=1,
cross_attention_dim=None,
num_class_embeds=None,
upcast_attention=False,
use_flash_attention=False
)
self.scheduler = DDPMScheduler(
num_train_timesteps=1000,
schedule="scaled_linear_beta",
beta_start=0.0015,
beta_end=0.0205
)
self.inferer = LatentDiffusionInferer(
scheduler=self.scheduler,
scale_factor=self.hparams.scale_factor,
#sldm_latent_shape=[32, 12, 48, 48] # monai-generative from git: 894f2ec3452d1bedbcd72e9842e1fa1fb4535ce9
)
self.l1_loss = L1Loss()
def forward(self, x):
pass
def configure_optimizers(self):
optimizer = torch.optim.Adam(
params=self.unet.parameters(),
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay
)
return optimizer
def training_step(self, batch, batch_idx):
step_results = self._shared_step(batch=batch, batch_idx=batch_idx, prefix="train")
self.log(
name="train_loss",
value=step_results["loss"],
prog_bar=True,
logger=False,
on_step=True,
on_epoch=True,
sync_dist=self.hparams.sync_dist,
batch_size=step_results["batch_size"]
)
self.training_step_outputs.append(step_results)
return step_results
def on_train_epoch_end(self):
self._shared_epoch_end(self.training_step_outputs, "train")
self.training_step_outputs.clear()
def validation_step(self, batch, batch_idx):
step_results = self._shared_step(batch=batch, batch_idx=batch_idx, prefix="valid")
self.log(
name="valid_loss",
value=step_results["loss"],
prog_bar=True,
logger=False,
on_step=True,
on_epoch=True,
sync_dist=self.hparams.sync_dist,
batch_size=step_results["batch_size"]
)
self.validation_step_outputs.append(step_results)
return step_results
def on_validation_epoch_end(self):
self._shared_epoch_end(self.validation_step_outputs, "valid")
# reset logging for next validation step
self.logged_valid_imgs = False
self.validation_step_outputs.clear()
def _shared_step(self, batch, batch_idx, prefix):
x = batch["image"]
ids = batch["name"]
if self.z_example is None:
self.z_example = self.autoencoder.encode_stage_2_inputs(x)
print(f"Size of latent space images: {tuple(self.z_example.size())[1:]}")
loss_dict = self._shared_ldm(
images=x,
ids = ids,
prefix=prefix
)
return dict(loss_dict, **{"batch_size": len(x), "device": x.device})
def _shared_ldm(self, images, ids, prefix):
# Generate random noise
# noise = torch.randn_like(self.z_example)
# --> this does not work, if batch-size is different
input_size = ((len(images), ) + tuple(self.z_example.size())[1:])
noise = torch.rand(
input_size,
dtype=self.z_example.dtype,
layout=self.z_example.layout,
device=images.device
)
# Create timesteps
timesteps = torch.randint(
0, self.inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()
# Get model prediction
noise_pred = self.inferer(
inputs=images,
autoencoder_model=self.autoencoder,
diffusion_model=self.unet,
noise=noise,
timesteps=timesteps
)
#loss = F.mse_loss(noise_pred.float(), noise.float())
loss = self.l1_loss(input=noise_pred.float(), target=noise.float())
if prefix == "valid" and not self.logged_valid_imgs:
if self.current_epoch % 10 == 0 and str(images.device) == self.hparams.log_device:
noise = torch.randn(((8, ) + tuple(self.z_example.size())[1:]))
noise = noise.to(images.device)
self.scheduler.set_timesteps(num_inference_steps=1000)
synthetic_images = self.inferer.sample(
input_noise=noise,
autoencoder_model=self.autoencoder,
diffusion_model=self.unet,
scheduler=self.scheduler
)
synth_images = synthetic_images.detach().cpu().numpy()
for _i in range(0, synth_images.shape[0]):
plot_2d_or_3d_image(
data=synth_images,
step=self.global_step,
writer=self.logger.experiment,
index=_i,
max_channels=1,
frame_dim=-3,
tag="synthetic_img/" + str(_i + 1)
)
self.logged_valid_imgs = True
return {"loss": loss}
def _shared_epoch_end(self, outputs, prefix):
# concat batch sizes
batch_sizes = np.stack(
[x["batch_size"] for x in outputs]
)
# concat losses
losses = np.stack(
[x["loss"].item() for x in outputs]
)
# calculating weighted mean loss
avg_loss = np.sum(losses * batch_sizes) / np.sum(batch_sizes)
self.log(
name="loss/" + prefix,
value=avg_loss,
prog_bar=False,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=self.hparams.sync_dist
) |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 16 replies
-
Can i see the params for your AE model? And how long have you trained for? |
Beta Was this translation helpful? Give feedback.
I found my mistake: I've accidentally used
torch.rand()
, which samples noise from a uniform distribution instead oftorch.randn()
, sampling from a normal distribution. So the batch-size-generic code to sample noise should look like this: