diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py new file mode 100644 index 000000000..175a54c85 --- /dev/null +++ b/examples/selective2d/2d_train.py @@ -0,0 +1,478 @@ +""" +This training script updates NanoGPT to run with either TP, PP, or TP+PP (2D). +Usage: +gpurun4 torchrun --nproc-per-node 4 2d_train.py +""" + +import argparse +import os +import time + +import torch +import torch.distributed as dist + +from model import GPT, GPTConfig +from pippy.compile import compile_stage + +from pippy.IR import annotate_split_points, PipeSplitWrapper +from pippy.microbatch import sum_reducer, TensorChunkSpec + +from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PairwiseParallel, + parallelize_module, + RowwiseParallel, +) + + +def get_args(): + # default config values designed to train a gpt2 (124M) on OpenWebText + + def str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "t", "1"): + return True + elif v.lower() in ("false", "f", "0"): + return False + else: + raise ArgumentTypeError("Boolean expected.") + + # I/O + parser = argparse.ArgumentParser() + parser.add_argument("--out_dir", type=str, default="out") + parser.add_argument("--eval_interval", type=int, default=2000) + parser.add_argument("--log_interval", type=int, default=2) + parser.add_argument("--eval_iters", type=int, default=200) + parser.add_argument( + "--eval_only", type=str_to_bool, default=False + ) # if True, script exits right after the first eval + parser.add_argument( + "--always_save_checkpoint", type=str_to_bool, default=True + ) # if True, always save a checkpoint after each eval + parser.add_argument( + "--init_from", type=str, default="scratch" + ) # 'scratch', 'resume', 'gpt2*' + parser.add_argument("--train_iters", type=int, default=200000) + parser.add_argument("--seed", type=int, default=1337) + + # data + parser.add_argument( + "--dataset", type=str, default="shakespeare_char" + ) # "openwebtext" + parser.add_argument( + "--gradient_accumulation_steps", type=int, default=1 + ) # used to simulate larger batch sizes + parser.add_argument( + "--batch_size", type=int, default=12 + ) # if gradient_accumulation_steps > 1, this is the micro-batch size + parser.add_argument("--block_size", type=int, default=1024) + parser.add_argument("--vocab_size", type=int, default=50304) + + # model + parser.add_argument("--n_layer", type=int, default=12) + parser.add_argument("--n_head", type=int, default=12) + parser.add_argument("--n_embd", type=int, default=768) + parser.add_argument( + "--dropout", type=float, default=0.0 + ) # for pretraining 0 is good, for finetuning try 0.1+ + parser.add_argument("--bias", type=str_to_bool, default=False) + + # adamw optimizer + parser.add_argument( + "--learning_rate", type=float, default=4e-4 + ) # max learning rate + parser.add_argument( + "--max_iters", type=int, default=600000 + ) # total number of training iterations + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument( + "--grad_clip", type=float, default=1.0 + ) # clip gradients at this value, or disable if == 0.0 + parser.add_argument( + "--decay_lr", type=str_to_bool, default=True + ) # whether to decay the learning rate + parser.add_argument("--warmup_iters", type=int, default=2000) + parser.add_argument("--lr_decay_iters", type=int, default=600000) + 
parser.add_argument( + "--min_lr", type=float, default=6e-5 + ) # minimum learning rate + + # distributed + parser.add_argument( + "--backend", type=str, default="nccl" + ) # 'nccl', 'gloo', etc. + parser.add_argument( + "--compile", type=str_to_bool, default=False + ) # use PyTorch 2.0 to compile the model to be faster + parser.add_argument("--rank", type=int, default=int(os.environ["RANK"])) + parser.add_argument( + "--local_rank", type=int, default=int(os.environ["LOCAL_RANK"]) + ) + parser.add_argument( + "--world_size", type=int, default=int(os.environ["WORLD_SIZE"]) + ) + parser.add_argument( + "--device", type=str, default=f"cuda:{os.environ['LOCAL_RANK']}" + ) + parser.add_argument( + "--master_process", + type=str_to_bool, + default=bool(os.environ["RANK"] == 0), + ) + parser.add_argument("--tp_size", type=int, default=2) + parser.add_argument("--pp_size", type=int, default=2) + + parser.add_argument("--debug", dest="debug", action="store_true") + + args = parser.parse_args() + + return args + + +def rank_print(x): + _rank = os.getenv("RANK") + if _rank == "0": + print(x) + + +def get_rand(args): + x = torch.randint( + 0, + args.vocab_size, + (args.batch_size, args.block_size), + device=args.device, + ) + y = torch.randint( + 0, + args.vocab_size, + (args.batch_size, args.block_size), + device=args.device, + ) + return x, y + + +def tp_attention(model, name, mesh, tp_dim=0, q="q", k="k", v="v", o="c_proj"): + layer = model.get_submodule(name) + parallelize_module( + layer, + mesh, + { + q: ColwiseParallel(), + k: ColwiseParallel(), + v: ColwiseParallel(), + o: RowwiseParallel(), + }, + tp_mesh_dim=tp_dim, + ) + + return model + + +def tp_mlp(model, name, mesh, tp_dim=0, mlp="mlp"): + layer = model.get_submodule(name) + parallelize_module( + layer, mesh, {mlp: PairwiseParallel()}, tp_mesh_dim=tp_dim + ) + + return model + + +def tp(model, n_layer, mesh, offset=0, tp_dim=0): + for i in range(n_layer): + block = model.get_submodule(f"transformer.h.{i + offset}") + parallelize_module( + block, + mesh, + { + "attn.q": ColwiseParallel(), + "attn.k": ColwiseParallel(), + "attn.v": ColwiseParallel(), + "attn.c_proj": RowwiseParallel(), + "mlp": PairwiseParallel(), + }, + tp_mesh_dim=tp_dim, + ) + + return model + + +def pp(model, pp_device_mesh, args): + pp_chunks = args.world_size + pp_groups = pp_device_mesh.get_dim_groups()[0] + + output_chunk_spec = (TensorChunkSpec(0), sum_reducer) + stage = compile_stage( + model, + args.rank, + args.world_size, + pp_chunks, + pp_device_mesh, + pp_groups, + example_inputs=[X, Y], + output_chunk_spec=output_chunk_spec, + ) + + print(f"[Rank{_rank}] {stage.submod.print_readable()}") + return model, stage + + +def pp_and_tp(model, mesh, args): + """ + Apply TP and PP to all layers in a model + This function assumes the model is already cut manually + """ + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + # TP + tp(model, args.n_layer, mesh, 0, tp_dim) + + X, Y = get_rand(args) + + # PP + stage = compile_stage( + model, + pp_rank, + args.world_size, + args.pp_size, + args.device, + pp_groups, + example_inputs=[X, Y], + ) + + return model, stage + + +def even_cut(model, args, pp_size): + """ + Evenly cut a model into pp_size stages + """ + cut = {} + cutpoint = args.n_layer // pp_size + for i in range(args.n_layer): + name = f"transformer.h.{i}" + if i > 0 and i % cutpoint == 0: + cut[name] = PipeSplitWrapper.SplitPoint.BEGINNING # or END + + 
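The pp_and_tp path above derives each rank's pipeline and tensor-parallel coordinates from the flat global rank, and parallelizes attention with column-parallel q/k/v plus a row-parallel c_proj. A minimal standalone sketch of that decomposition (pure Python, no process groups created; world_size=4 and tp_size=2 follow the example's defaults):

# Sketch: how a flat rank maps onto the (pp, tp) grid used above.
world_size, tp_size = 4, 2
for rank in range(world_size):
    pp_rank = rank // tp_size  # which pipeline stage this rank belongs to
    tp_rank = rank % tp_size   # position inside its tensor-parallel group
    print(f"rank {rank} -> pp_rank {pp_rank}, tp_rank {tp_rank}")
# rank 0 -> pp_rank 0, tp_rank 0
# rank 1 -> pp_rank 0, tp_rank 1
# rank 2 -> pp_rank 1, tp_rank 0
# rank 3 -> pp_rank 1, tp_rank 1
#
# ColwiseParallel on q/k/v shards the output feature dimension, so each tp rank
# holds n_head // tp_size heads; RowwiseParallel on c_proj shards the input
# feature dimension and all-reduces the partial sums, restoring the full hidden size.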
annotate_split_points(model, cut) + + +def after_ar_cut(model, args, pp_size): + """ + Cut a model right after AllReduce happens + """ + cut = {} + cutpoint = args.n_layer // pp_size + for i in range(args.n_layer): + name = f"transformer.h.{i}" + if i != args.n_layer - 1 and i % cutpoint == cutpoint - 1: + cut[f"{name}.mlp.dropout"] = PipeSplitWrapper.SplitPoint.BEGINNING + + annotate_split_points(model, cut) + + +def pp_and_tp_selective( + model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut_fn=even_cut +): + """ + Apply pipeline parallelism and tensor parallelism to a model. + """ + + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + # TP + # Apply TP to layers if layer_id is in tp_attn / tp_mlp + tp_attn_layers = ( + list(range(args.n_layer)) if tp_attn_layers is None else tp_attn_layers + ) + tp_mlp_layers = ( + list(range(args.n_layer)) if tp_mlp_layers is None else tp_mlp_layers + ) + for i in range(args.n_layer): + name = f"transformer.h.{i}" + att = tp_attention(model, f"{name}.attn", mesh, tp_dim) + mlp = tp_mlp(model, f"{name}", mesh, tp_dim) + + X, Y = get_rand(args) + + # PP + cut_fn(model, args, args.pp_size) + stage = compile_stage( + model, + pp_rank, + args.world_size, + args.pp_size, + args.device, + pp_groups, + example_inputs=[X, Y], + ) + + return model, stage + + +def pp_tp_train(stage, mesh, args): + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + train_iters = 10 if args.debug else args.train_iters + optimizer = torch.optim.AdamW( + stage.submod.parameters(), lr=args.learning_rate + ) + local_iter_num = 0 + iter_time = 0.0 + while local_iter_num < train_iters: + optimizer.zero_grad() + t0 = time.perf_counter() + X, Y = get_rand(args) + if pp_rank == 0: + out = stage(X) + elif pp_rank == args.pp_size - 1: + out = stage(Y) + else: + out = stage() + optimizer.step() + t1 = time.perf_counter() + dt = t1 - t0 + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +def pp_train(stage, args): + train_iters = 10 if args.debug else args.train_iters + optimizer = torch.optim.AdamW( + stage.submod.parameters(), lr=args.learning_rate + ) + local_iter_num = 0 + iter_time = 0.0 + while local_iter_num < train_iters: + optimizer.zero_grad() + t0 = time.perf_counter() + X, Y = get_rand(args) + if args.rank == 0: + out = stage(X) + elif args.rank == args.world_size - 1: + out = stage(Y) + else: + out = stage() + optimizer.step() + t1 = time.perf_counter() + dt = t1 - t0 + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +def tp_train(): + local_iter_num = 0 + iter_time = 0.0 + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + while local_iter_num < train_iters: + optimizer.zero_grad(set_to_none=True) + t0 = time.perf_counter() + X, Y = get_rand(args) + logits, loss = model(X, Y) + loss.backward() + optimizer.step() + torch.distributed.barrier() + t1 = time.perf_counter() + dt = t1 - t0 + lossf = loss.item() + rank_print( + f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms" + ) + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +if __name__ == "__main__": + _multi_gpu = int(os.environ.get("RANK", -1)) != -1 # verify distributed run + assert ( + _multi_gpu + ), "this config assumes distributed setup - multi-gpu not ready here." 
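To make the two cut policies concrete, here is a small standalone recomputation of the split-point dictionaries they build for the defaults n_layer=12 and pp_size=2. The module names mirror the paths used above; the real functions store PipeSplitWrapper.SplitPoint.BEGINNING, strings stand in for it here.

# Sketch: which submodules even_cut and after_ar_cut would annotate.
n_layer, pp_size = 12, 2
cutpoint = n_layer // pp_size  # 6

even = {
    f"transformer.h.{i}": "BEGINNING"
    for i in range(n_layer)
    if i > 0 and i % cutpoint == 0
}
after_ar = {
    f"transformer.h.{i}.mlp.dropout": "BEGINNING"
    for i in range(n_layer)
    if i != n_layer - 1 and i % cutpoint == cutpoint - 1
}
print(even)      # {'transformer.h.6': 'BEGINNING'} -> stage boundary before block 6
print(after_ar)  # {'transformer.h.5.mlp.dropout': 'BEGINNING'} -> cut right after block 5's all-reduce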
+ + args = get_args() + + device_type = ( + "cuda" if "cuda" in args.device else "cpu" + ) # for later use in torch.autocast + torch.cuda.set_device(args.device) + + dist.init_process_group( + backend=args.backend, rank=args.rank, world_size=args.world_size + ) + + if args.master_process: + os.makedirs(args.out_dir, exist_ok=True) + + torch.manual_seed(args.seed) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + torch.backends.cuda.enable_mem_efficient_sdp(enabled=False) + + # init these up here, can override if init_from='resume' (i.e. from a checkpoint) + iter_num = 0 + best_val_loss = 1e9 + + # model init + model_args = dict( + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + block_size=args.block_size, + bias=args.bias, + vocab_size=None, + dropout=args.dropout, + ) # start with model_args from command line + + # init a new model from scratch + rank_print("Initializing a new model from scratch") + + oned_mesh = DeviceMesh(device_type, list(range(args.world_size))) + twod_mesh = DeviceMesh( + device_type=device_type, + mesh=torch.arange(0, args.world_size).view(-1, args.tp_size), + ) + + model_args["vocab_size"] = args.vocab_size + + gptconf = GPTConfig(**model_args) + model = GPT(twod_mesh, gptconf, args.device, args.pp_size) + model.to(args.device) + + _current_model_params = model.get_num_params() / 1e6 + + # model = tp(model, args.n_layer, oned_mesh) + # model, stage = pp(model, oned_mesh, args) + # model, stage = pp_and_tp(model, twod_mesh, args) + model, stage = pp_and_tp_selective(model, twod_mesh, args) + + # iter_count, iter_time = pp_train(stage, args) + iter_count, iter_time = pp_tp_train(stage, twod_mesh, args) + + # display run stats + rank_print(f"\nTraining completed.\n") + + gpu_type = torch.cuda.get_device_name(0) + gpu_count = dist.get_world_size() + rank_print(f"\n----- Performance Stats --------\n") + rank_print(f"\nModel Size: {_current_model_params:.2f}M") + rank_print(f"Run completed with {gpu_count} gpus, of type {gpu_type}") + iter_avg = round(iter_time / iter_count, 4) + rank_print( + f"Avg iter speed (in seconds): {iter_avg}, with {iter_count} iterations averaged.\n" + ) + + dist.destroy_process_group() diff --git a/examples/selective2d/model.py b/examples/selective2d/model.py new file mode 100644 index 000000000..3bfd07d25 --- /dev/null +++ b/examples/selective2d/model.py @@ -0,0 +1,540 @@ +# Original code from https://github.com/karpathy/nanoGPT +""" +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" +import inspect + +import math +import os +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) +def new_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
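The 2D mesh built in the main block above arranges the flat ranks into a (pipeline, tensor-parallel) grid. A quick sketch of the layout it produces, showing only the tensor that backs the mesh (world_size=4 and tp_size=2 match the example's defaults; no devices are touched):

import torch

# Sketch of the mesh tensor passed to DeviceMesh above: rows are pipeline stages
# (dim 0), columns are tensor-parallel peers (dim 1).
world_size, tp_size = 4, 2
mesh = torch.arange(0, world_size).view(-1, tp_size)
print(mesh)
# tensor([[0, 1],
#         [2, 3]])
# get_dim_groups()[0] therefore pairs ranks {0, 2} and {1, 3} for pipeline
# communication, while dim 1 pairs {0, 1} and {2, 3} for tensor parallelism.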
+ Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, mesh, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + self.mesh = mesh + + def forward(self, input): + return F.layer_norm( + input, self.weight.shape, self.weight, self.bias, 1e-5 + ) + + +class CausalSelfAttention(nn.Module): + def __init__(self, mesh, config): + super().__init__() + tp_size = mesh.mesh.size(0) + assert config.n_head % tp_size == 0 + assert config.n_embd % config.n_head == 0 + self.mesh = mesh + self.tp_size = tp_size + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear( + config.n_embd, 3 * config.n_embd, bias=config.bias + ) + self.q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.k = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.v = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = ( + hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.dropout == 0.0 + ) + + if not self.flash: + print( + "WARNING: using slow attention. 
Flash Attention atm needs PyTorch nightly and dropout=0.0" + ) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.block_size = config.block_size + self.register_buffer( + "bias", + torch.tril( + torch.ones(config.block_size, config.block_size) + ).view(1, 1, config.block_size, config.block_size), + ) + + def forward(self, x): + ( + B, + T, + C, + ) = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) + + def print0(msg): + if os.getenv("RANK") == "0": + print(msg) + + channel_head_size = C // self.n_head + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = ( + self.q(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + k = ( + self.k(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + v = ( + self.v(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) + else: + # manual implementation of attention + from torch.distributed._tensor import ( + DeviceMesh, + distribute_tensor, + Replicate, + Shard, + ) + + mesh = DeviceMesh("cuda", list(range(2))) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C // self.tp_size) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear( + config.n_embd, 4 * config.n_embd, bias=config.bias + ) + self.gelu = nn.GELU() + self.c_proj = nn.Linear( + 4 * config.n_embd, config.n_embd, bias=config.bias + ) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, mesh, config): + super().__init__() + self.ln_1 = LayerNorm(mesh, config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(mesh, config) + self.ln_2 = LayerNorm(mesh, config.n_embd, bias=config.bias) + self.mlp = MLP(config) + self.mesh = mesh + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
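The split(...)[0] pattern in CausalSelfAttention.forward keeps one C // tp_size chunk of each projection, i.e. this rank's share of the heads. A shape walkthrough with plain tensors (no DTensor, no process groups) under the example defaults n_embd=768, n_head=12, tp_size=2; B and T are illustrative:

import torch

# Shape sketch of the tensor-parallel attention path above.
B, T, C, n_head, tp_size = 2, 16, 768, 12, 2
head_size = C // n_head                          # 64
proj = torch.randn(B, T, C)                      # stands in for self.q(x)
q_local = proj.split(C // tp_size, dim=2)[0]     # keep one (B, T, 384) chunk, i.e. 6 heads
q = q_local.view(B, T, n_head // tp_size, head_size).transpose(1, 2)
print(q.shape)  # torch.Size([2, 6, 16, 64]) -> (B, local heads, T, head size)
y = q.transpose(1, 2).contiguous().view(B, T, C // tp_size)
print(y.shape)  # torch.Size([2, 16, 384]); the row-parallel c_proj maps this back to (B, T, 768)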
False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, mesh, config, device, pp_size=2): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + self.mesh = mesh + self.pp_size = pp_size + self.device = device + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList( + [Block(mesh, config) for _ in range(config.n_layer)] + ), + ln_f=LayerNorm(mesh, config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + # report number of parameters + print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. 
+ """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, idx, targets=None): + # device = idx.device + # b, t = idx.size() + # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # WARNING: t needs to actual sequence length, shape should be (1,t) + pos = torch.arange( + 0, self.config.block_size, dtype=torch.long, device=self.device + ).unsqueeze(0) + + # forward the GPT model itself + tok_emb = self.transformer.wte( + idx + ) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-1, + ) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim + loss = None + + return logits, loss + + def crop_block_size(self, block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert block_size <= self.config.block_size + self.config.block_size = block_size + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:block_size] + ) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] + + @classmethod + def from_pretrained(cls, model_type, override_args=None): + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + override_args = override_args or {} # default to empty dict + # only dropout can be overridden see more notes below + assert all(k == "dropout" for k in override_args) + from transformers import GPT2LMHeadModel + + print("loading weights from pretrained gpt: %s" % model_type) + + # n_layer, n_head and n_embd are determined from model_type + config_args = { + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params + "gpt2-medium": dict( + n_layer=24, n_head=16, n_embd=1024 + ), # 350M params + "gpt2-large": dict( + n_layer=36, n_head=20, n_embd=1280 + ), # 774M params + "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + print("forcing vocab_size=50257, block_size=1024, bias=True") + config_args[ + "vocab_size" + ] = 50257 # always 50257 for GPT model checkpoints + config_args[ + "block_size" + ] = 1024 # always 1024 for GPT model checkpoints + config_args["bias"] = True # always True for GPT model checkpoints + # we can override the dropout rate, if desired + if "dropout" in override_args: + print(f"overriding dropout rate to {override_args['dropout']}") + config_args["dropout"] = override_args["dropout"] + # create a from-scratch initialized minGPT model + config = 
GPTConfig(**config_args) + model = GPT(config) + sd = model.state_dict() + sd_keys = sd.keys() + sd_keys = [ + k for k in sd_keys if not k.endswith(".attn.bias") + ] # discard this mask / buffer, not a param + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + sd_keys_hf = sd_hf.keys() + sd_keys_hf = [ + k for k in sd_keys_hf if not k.endswith(".attn.masked_bias") + ] # ignore these, just a buffer + sd_keys_hf = [ + k for k in sd_keys_hf if not k.endswith(".attn.bias") + ] # same, just the mask (buffer) + transposed = [ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + ] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(sd_keys_hf) == len( + sd_keys + ), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers( + self, weight_decay, learning_rate, betas, device_type + ): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = ( + torch.nn.LayerNorm, + LayerNorm, + torch.nn.Embedding, + ) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times. but doing it this way + # allows us to know which parent module any tensor p belongs to... + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance( + m, whitelist_weight_modules + ): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance( + m, blacklist_weight_modules + ): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they + # will appear in the no_decay and decay sets respectively after the above. + # In addition, because named_parameters() doesn't return duplicates, it + # will only return the first occurence, key'd by 'transformer.wte.weight', below. + # so let's manually remove 'lm_head.weight' from decay set. This will include + # this tensor into optimization via transformer.wte.weight only, and not decayed. 
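The comment above relies on named_parameters() deduplicating tied tensors. A quick standalone check of that behaviour with a toy embedding/head pair, mirroring the wte / lm_head tying done in GPT.__init__:

import torch.nn as nn

# Tiny check of the tied-weight behaviour the comment above depends on:
# named_parameters() skips duplicates, so the shared tensor shows up once,
# under the first module that registered it.
wte = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
wte.weight = lm_head.weight                      # weight tying, as in GPT.__init__
tied = nn.ModuleDict({"wte": wte, "lm_head": lm_head})
print([name for name, _ in tied.named_parameters()])
# ['wte.weight']  -> 'lm_head.weight' never appears, hence the decay.remove() below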
+ decay.remove("lm_head.weight") + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % ( + str(inter_params), + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters %s were not separated into either decay/no_decay set!" + % (str(param_dict.keys() - union_params),) + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + # new PyTorch nightly has a new 'fused' option for AdamW that is much faster + use_fused = (device_type == "cuda") and ( + "fused" in inspect.signature(torch.optim.AdamW).parameters + ) + use_fused = False # YEONJU + print(f"using fused AdamW: {use_fused}") + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW( + optim_groups, lr=learning_rate, betas=betas, **extra_args + ) + + return optimizer + + def estimate_mfu(self, num_params, fwdbwd_per_iter, dt): + """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" + # first estimate the number of flops we do per iteration. + # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 + # N = self.get_num_params() + N = num_params + cfg = self.config + tp_size = 2 + actual_head = cfg.n_head // tp_size + # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size + L, H, Q, T = ( + cfg.n_layer, + actual_head, + cfg.n_embd // actual_head, + cfg.block_size, + ) + flops_per_token = 6 * N + 12 * L * H * Q * T + flops_per_fwdbwd = flops_per_token * T + flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter + # express our flops throughput as ratio of A100 bfloat16 peak flops + flops_achieved = flops_per_iter * (1.0 / dt) # per second + # flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + # mfu = flops_achieved / flops_promised + flops_promised = 125e12 # A10 TFlops .... 312e12 A100 GPU bfloat16 peak flops is 312 TFLOPS + mfu = (flops_achieved / flops_promised) / tp_size + return mfu + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
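Plugging the example's default config into estimate_mfu above gives a feel for the magnitudes involved. This is rough arithmetic only: N is taken as ~124e6 (the usual GPT-2 small count), tp_size=2 matches the value hard-coded in estimate_mfu, and dt is illustrative.

# Rough walk-through of the FLOP estimate in estimate_mfu for the default config.
N = 124e6
L, n_embd, T, tp_size = 12, 768, 1024, 2
H = 12 // tp_size                               # heads per tp rank (6)
Q = n_embd // H                                 # per-head width as computed above (128)
flops_per_token = 6 * N + 12 * L * H * Q * T    # ~7.44e8 + ~1.13e8 = ~8.57e8
flops_per_fwdbwd = flops_per_token * T          # ~8.78e11 per block_size sequence
dt = 1.0                                        # seconds per iteration, illustrative
mfu = (flops_per_fwdbwd / dt / 125e12) / tp_size  # 125e12 = A10 bf16 peak used above
print(f"{mfu:.4f}")                             # ~0.0035 for this toy dt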
+ """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = ( + idx + if idx.size(1) <= self.config.block_size + else idx[:, -self.config.block_size :] + ) + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx