added max-of-n #2
Conversation
Initial review comments
assert W_pos.shape == (
    n_ctx,
    d_model,
), f"W_pos.shape = {W_pos.shape} != {(n_ctx, d_model)} = (n_ctx, d_model)"
assert W_Q.shape == (
    1,
    1,
    d_model,
    d_model,
), f"W_Q.shape = {W_Q.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
assert W_K.shape == (
    1,
    1,
    d_model,
    d_model,
), f"W_K.shape = {W_K.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
assert W_E.shape == (
    d_vocab,
    d_model,
), f"W_E.shape = {W_E.shape} != {(d_vocab, d_model)} = (d_vocab, d_model)"
I don't suppose there's any way to tell the code formatter to compress the shape checking?
My ideal here would be to annotate the assignment of W_pos, etc., with types that include tensor dimensions, and import some package for runtime type checking. Do you know if this is possible with, e.g., jaxtyping, torchtyping, typeguard, beartype, mypy, etc.? (Maybe cf. patrick-kidger/jaxtyping#153.)
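For the function-boundary version, something like this might work (a rough sketch, not tested against this PR: pos_qk_circuit is a made-up helper, the dimension names in the annotations are assumptions, and the @jaxtyped(typechecker=...) form needs a reasonably recent jaxtyping):

from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def pos_qk_circuit(
    W_pos: Float[Tensor, "n_ctx d_model"],
    W_Q: Float[Tensor, "1 1 d_model d_model"],
    W_K: Float[Tensor, "1 1 d_model d_model"],
) -> Float[Tensor, "n_ctx n_ctx"]:
    # Shapes are checked at call time; a mismatched tensor raises immediately,
    # replacing the explicit assert blocks above.
    return W_pos @ W_Q[0, 0] @ W_K[0, 0].T @ W_pos.T

This only checks at function boundaries, though, so bare assignments would still need asserts or a small checked helper.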
@dataclass
class MaxOfN(ExperimentConfig):
    # Max of n (iterable dataset)
    n_train_samples: Optional[int] = None  # if none, infinite dataset
Nit: The dataset isn't actually infinite, though, right? Is there a better descriptor?
simpler_cfg = HookedTransformerConfig(
    d_model=config.d_model,
    n_layers=config.n_layers,
    n_heads=config.n_heads,
    d_head=config.d_head,
    n_ctx=config.n_ctx,
    d_vocab=config.d_vocab,
    seed=config.seed,
    attn_only=True,
    normalization_type=None,
    # device=default_device(deterministic=config.deterministic),
)
model = HookedTransformer(simpler_cfg)
Can we add a config option for float32 vs float64?
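For instance, a minimal sketch (PrecisionConfig / float_dtype / apply_precision are made-up names here; if HookedTransformerConfig already accepts a dtype kwarg, passing it there would be cleaner than casting after construction):

import torch
from dataclasses import dataclass

# Hypothetical sketch: a precision knob on the experiment config,
# applied to the model after construction.
@dataclass
class PrecisionConfig:
    float_dtype: torch.dtype = torch.float32  # or torch.float64

def apply_precision(model: torch.nn.Module, cfg: PrecisionConfig) -> torch.nn.Module:
    # nn.Module.to(dtype) casts all floating-point parameters and buffers
    return model.to(cfg.float_dtype)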
) -> Float[Tensor, ""]:
    logits = logits[:, -1, :]
    true_maximum = torch.max(tokens, dim=1)[0]
    log_probs = logits.log_softmax(-1)
Do we want to do the more accurate (currently untested) version below?
Suggested change:
-log_probs = logits.log_softmax(-1)
+# log_softmax is only accurate to around 2e-7, cf https://github.com/pytorch/pytorch/issues/113708
+# we can get better precision by using log1p
+logits_max_idxs = logits.argmax(dim=-1, keepdim=True)
+logits_centered = logits - logits.gather(dim=-1, index=logits_max_idxs)
+logits_exp = logits_centered.exp()
+# logits_exp[max] will be 1, so we can zero it and use log1p(x) = log(1 + x)
+logits_exp.scatter_(dim=-1, index=logits_max_idxs, value=0.0)
+log_probs = logits_centered - logits_exp.sum(dim=-1, keepdim=True).log1p()
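If we go that route, a quick sanity check along these lines (untested sketch, comparing against the builtin in float64) would make me more comfortable:

import torch

# Untested sketch: compare the log1p-based log-softmax against torch's builtin.
def log_softmax_log1p(logits: torch.Tensor) -> torch.Tensor:
    max_idxs = logits.argmax(dim=-1, keepdim=True)
    centered = logits - logits.gather(dim=-1, index=max_idxs)
    exps = centered.exp()
    exps.scatter_(dim=-1, index=max_idxs, value=0.0)  # zero out the max term
    return centered - exps.sum(dim=-1, keepdim=True).log1p()

logits = torch.randn(8, 64, dtype=torch.float64)
assert torch.allclose(log_softmax_log1p(logits), logits.log_softmax(-1))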
    lr=self.config.optimizer_kwargs["lr"],
    betas=self.config.optimizer_kwargs["betas"],
More compactly generalizable alternative (pick whichever seems better to you):
Suggested change:
-lr=self.config.optimizer_kwargs["lr"],
-betas=self.config.optimizer_kwargs["betas"],
+**{k: self.config.optimizer_kwargs[k] for k in ("lr", "betas")}
super().__init__(config)
self.config = config
self.seq_len = config.n_ctx
self.dataset_seed = config.seed * 10 + 1
Why this seed? (Leave a comment about how (non)arbitrary this is)
if __name__ == "__main__":
    print("Training model:", MAX_OF_10_CONFIG)
    train_or_load_model(MAX_OF_10_CONFIG, force="train")
I don't think this should be in exp_max_of_n/train.py; maybe in exp_max_of_n/train_max_of_10.py? Or else we should take command-line arguments (argparse? get ChatGPT to quickly write up the argparse code?) for --max-of {2|10} [--force-train].
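Untested sketch of the argparse wiring (assumes a MAX_OF_2_CONFIG exists alongside MAX_OF_10_CONFIG in this module):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a max-of-n model")
    parser.add_argument("--max-of", type=int, choices=(2, 10), default=10)
    parser.add_argument("--force-train", action="store_true")
    args = parser.parse_args()
    # MAX_OF_2_CONFIG is assumed here; only MAX_OF_10_CONFIG appears in this diff
    config = {2: MAX_OF_2_CONFIG, 10: MAX_OF_10_CONFIG}[args.max_of]
    print("Training model:", config)
    train_or_load_model(config, force="train" if args.force_train else None)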
def train_or_load_model(
    config: Config,
    force: Optional[Literal["train", "load"]] = None,
Does it make sense to use enums or global constants (TRAIN = "train") here rather than strings, to make it harder to typo?
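For what it's worth, a small sketch of the constants variant (an Enum would also work, but would change the public type of force):

from typing import Literal

# Sketch: module-level constants so a typo becomes a NameError (or a type
# checker error) instead of a silently-ignored string.
FORCE_TRAIN: Literal["train"] = "train"
FORCE_LOAD: Literal["load"] = "load"

# call site: train_or_load_model(MAX_OF_10_CONFIG, force=FORCE_TRAIN)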
test_metrics: Sequence[Mapping[str, float]]


def _load_model(
Do we also want to do __all__ = [...] to control import behavior?
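For example (names guessed from this diff; underscore-prefixed helpers like _load_model stay out of star-imports either way):

# Sketch: explicit export list for the module
__all__ = ["Config", "train_or_load_model"]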
elif unit == "epochs":
    trainer_args = {"max_epochs": n}
else:
    raise ValueError
Suggested change:
-raise ValueError
+raise ValueError(f"Invalid unit {unit}")
def train_or_load_model(
    config: Config,
    force: Optional[Literal["train", "load"]] = None,
    save_to: Optional[Literal["disk", "disk_and_wandb"]] = "disk_and_wandb",
Do we want to make this a bitmask instead, a la re.match's flags argument?
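A sketch of what that could look like with enum.Flag (names made up; probably only worth it if we expect more save targets than disk and wandb):

import enum

# Sketch: combinable save targets, analogous to re.match's flags argument.
class SaveTo(enum.Flag):
    NONE = 0
    DISK = enum.auto()
    WANDB = enum.auto()

save_to = SaveTo.DISK | SaveTo.WANDB
if save_to & SaveTo.DISK:
    print("would save to disk")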
return "cuda" if torch.cuda.is_available() and not deterministic else "cpu" | ||
|
||
|
||
T = TypeVar("T") |
?
from numpy.random import Generator
from torch import Tensor

PROJECT_ROOT = Path(__file__).parent.parent.parent
This seems fragile, especially since it goes outside the top-level package, and with respect to refactoring. Maybe we want to have a get_git_project_root? Or ensure it's a subdirectory of the package?
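Untested sketch of a get_git_project_root, falling back to the current heuristic when git isn't available:

import subprocess
from pathlib import Path

def get_git_project_root(start: Path = Path(__file__).parent) -> Path:
    # Ask git for the repository root instead of counting .parent hops.
    try:
        out = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=start, capture_output=True, text=True, check=True,
        )
        return Path(out.stdout.strip())
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not a git checkout (e.g. an installed package): keep the old behaviour.
        return Path(__file__).parent.parent.parent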
from typing import Any
from typing import Dict

# Implemented for https://github.com/lemon24/reader/issues/179
Did you write this code @euanong? Do we need to stick some license notice at the top?
I'm going to merge this now so we can build on it; code review comments can be addressed in another PR.
No description provided.