Skip to content
This repository has been archived by the owner on Oct 22, 2023. It is now read-only.

Use diffusers EMAModel #126

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions scripts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel,DiffusionPipeline, DPMSolverMultistepScheduler,EulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import functional
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
Expand Down Expand Up @@ -1077,6 +1078,7 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True):
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
unwrapped_unet = accelerator.unwrap_model(unet,True)
if args.use_ema:
ema_unet.store(unwrapped_unet.parameters())
ema_unet.copy_to(unwrapped_unet.parameters())

pipeline = DiffusionPipeline.from_pretrained(
Expand Down Expand Up @@ -1229,6 +1231,9 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True):
elif save_model == False and len(imgs) > 0:
del imgs
print(f"{bcolors.OKGREEN}Samples saved to {sample_dir}{bcolors.ENDC}")
if args.use_ema:
ema_unet.restore(unwrapped_unet.parameters())

except Exception as e:
print(e)
print(f"{bcolors.FAIL} Error occured during sampling, skipping.{bcolors.ENDC}")
Expand Down
58 changes: 0 additions & 58 deletions scripts/trainer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,61 +425,3 @@ def get_depth_image_path(self,image_path):
image_path = Path(image_path)
return image_path.parent / f"{image_path.stem}-depth.png"

# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 and taken from harubaru's implementation https://github.com/harubaru/waifu-diffusion
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
parameters = list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]

self.decay = decay
self.optimization_step = 0

def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
value = (1 + optimization_step) / (10 + optimization_step)
return 1 - min(self.decay, value)

@torch.no_grad()
def step(self, parameters):
parameters = list(parameters)

self.optimization_step += 1
self.decay = self.get_decay(self.optimization_step)

for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
tmp = self.decay * (s_param - param)
s_param.sub_(tmp)
else:
s_param.copy_(param)

torch.cuda.empty_cache()

def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = list(parameters)
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)

def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
device: like `device` argument to `torch.Tensor.to`
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
for p in self.shadow_params
]
2 changes: 1 addition & 1 deletion scripts/windows_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def check_versions():
status = shutil.copy2(src_file, cudnn_dest)
if status:
print("Copied CUDNN 8.6 files to destination")
d_commit = '8178c84'
d_commit = 'f727863'
diffusers_cmd = f"git+https://github.com/huggingface/diffusers.git@{d_commit}#egg=diffusers --force-reinstall"
run(f'"{python}" -m pip install {diffusers_cmd}', f"Installing Diffusers {d_commit} commit", "Couldn't install diffusers")
#install requirements file
Expand Down