From 4eaaa6cf68fc934aa65fe89b8933db658c7f48c7 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 7 May 2024 10:37:06 +0800 Subject: [PATCH] adapted pretrained model to training (#371) --- opensora/models/stdit/stdit2.py | 18 +++++++++++++++++- scripts/train.py | 3 +-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index afcc0c12..5de17694 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch.nn as nn +import os from einops import rearrange from rotary_embedding_torch import RotaryEmbedding from timm.models.layers import DropPath @@ -23,6 +24,7 @@ ) from opensora.registry import MODELS from transformers import PretrainedConfig, PreTrainedModel +from opensora.utils.ckpt_utils import load_checkpoint class STDiT2Block(nn.Module): @@ -502,8 +504,22 @@ def _basic_init(module): @MODELS.register_module("STDiT2-XL/2") def STDiT2_XL_2(from_pretrained=None, **kwargs): if from_pretrained is not None: - model = STDiT2.from_pretrained(from_pretrained, **kwargs) + if os.path.isdir(from_pretrained) or os.path.isfile(from_pretrained): + # if it is a directory or a file, we load the checkpoint manually + config = STDiT2Config( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, **kwargs + ) + model = STDiT2(config) + load_checkpoint(model, from_pretrained) + return model + else: + # otherwise, we load the model from hugging face hub + return STDiT2.from_pretrained(from_pretrained) else: + # create a new model config = STDiT2Config( depth=28, hidden_size=1152, diff --git a/scripts/train.py b/scripts/train.py index aed7b452..bfb6d39c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -133,8 +133,7 @@ def main(): input_size=latent_size, in_channels=vae.out_channels, caption_channels=text_encoder.output_dim, - model_max_length=text_encoder.model_max_length, - dtype=dtype, + model_max_length=text_encoder.model_max_length ) model_numel, model_numel_trainable = get_model_numel(model) logger.info(