
added max-of-n #2

Merged
merged 2 commits into main from eo-6-jan on Jan 7, 2024

Conversation

@euanong (Collaborator) commented Jan 6, 2024

No description provided.

@JasonGross (Owner) left a comment

Initial review comments

gbmi/__init__.py (resolved)
gbmi/analysis_tools/decomp.py (resolved)
gbmi/analysis_tools/l1h1.py (resolved)
gbmi/verification_tools/general.py (resolved)
Comment on lines +34 to +53
assert W_pos.shape == (
n_ctx,
d_model,
), f"W_pos.shape = {W_pos.shape} != {(n_ctx, d_model)} = (n_ctx, d_model)"
assert W_Q.shape == (
1,
1,
d_model,
d_model,
), f"W_Q.shape = {W_Q.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
assert W_K.shape == (
1,
1,
d_model,
d_model,
), f"W_K.shape = {W_K.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
assert W_E.shape == (
d_vocab,
d_model,
), f"W_E.shape = {W_E.shape} != {(d_vocab, d_model)} = (d_vocab, d_model)"
@JasonGross (Owner):

I don't suppose there's any way to tell the code formatter to compress the shape checking?

My ideal here would be to annotate the assignment of W_pos, etc., with types that include tensor dimensions, and import some package for runtime type checking. Do you know if this is possible with, e.g., jaxtyping, torchtyping, typeguard, beartype, mypy, etc.? (Maybe cf. patrick-kidger/jaxtyping#153.)
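
For reference, a minimal sketch (untested, assuming a recent jaxtyping with the typechecker= decorator form plus beartype) of what function-boundary shape checking could look like; the helper name check_shapes is made up here, and checking bare variable annotations is still the open question in jaxtyping#153:

from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def check_shapes(
    W_pos: Float[Tensor, "n_ctx d_model"],
    W_Q: Float[Tensor, "1 1 d_model d_model"],
    W_K: Float[Tensor, "1 1 d_model d_model"],
    W_E: Float[Tensor, "d_vocab d_model"],
) -> None:
    # Mismatched shapes, including an inconsistent d_model across arguments,
    # raise a type-checking error at call time, replacing the f-string asserts.
    pass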

@dataclass
class MaxOfN(ExperimentConfig):
# Max of n (iterable dataset)
n_train_samples: Optional[int] = None # if none, infinite dataset
@JasonGross (Owner):

Nit: The dataset isn't actually infinite, though, right? Is there a better descriptor?

Comment on lines +67 to +79
simpler_cfg = HookedTransformerConfig(
d_model=config.d_model,
n_layers=config.n_layers,
n_heads=config.n_heads,
d_head=config.d_head,
n_ctx=config.n_ctx,
d_vocab=config.d_vocab,
seed=config.seed,
attn_only=True,
normalization_type=None,
# device=default_device(deterministic=config.deterministic),
)
model = HookedTransformer(simpler_cfg)
@JasonGross (Owner):

Can we add config for float32 vs float64?
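
One possible shape for that (a sketch only; build_model and dtype_name are hypothetical names, not from this PR): keep a dtype entry in the config and cast the model once after construction, since nn.Module.to(dtype) converts all floating-point parameters and buffers.

import torch
import torch.nn as nn

DTYPES = {"float32": torch.float32, "float64": torch.float64}

def build_model(d_model: int = 32, dtype_name: str = "float64") -> nn.Module:
    model = nn.Linear(d_model, d_model)  # stand-in for HookedTransformer(simpler_cfg)
    return model.to(DTYPES[dtype_name])

print(next(build_model().parameters()).dtype)  # torch.float64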

) -> Float[Tensor, ""]:
logits = logits[:, -1, :]
true_maximum = torch.max(tokens, dim=1)[0]
log_probs = logits.log_softmax(-1)
@JasonGross (Owner):

Do we want to do the more accurate (currently untested) version below?

Suggested change
log_probs = logits.log_softmax(-1)
# log_softmax is only accurate to around 2e-7, cf https://github.com/pytorch/pytorch/issues/113708
# we can get better precision by using log1p
logits_max_idxs = logits.argmax(dim=-1, keepdim=True)
logits_centered = logits - logits.gather(dim=-1, index=logits_max_idxs)
logits_exp = logits_centered.exp()
# logits_exp[max] will be 1, so we can zero it and use log1p(x) = log(1 + x)
logits_exp.scatter_(dim=-1, index=logits_max_idxs, value=0.0)
log_probs = logits_centered - logits_exp.sum(dim=-1, keepdim=True).log1p()
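
A quick way to exercise the suggestion above (a sketch, not part of the PR): compare the log1p-based computation against torch.log_softmax on random float64 logits, where the two should agree well within allclose tolerances.

import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, dtype=torch.float64) * 20

idx = logits.argmax(dim=-1, keepdim=True)
centered = logits - logits.gather(dim=-1, index=idx)
exps = centered.exp()
exps.scatter_(-1, idx, 0.0)  # the max entry contributes exactly 1, folded into log1p
log_probs = centered - exps.sum(dim=-1, keepdim=True).log1p()

print(torch.allclose(log_probs, logits.log_softmax(-1)))  # expect True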

Comment on lines +127 to +128
lr=self.config.optimizer_kwargs["lr"],
betas=self.config.optimizer_kwargs["betas"],
@JasonGross (Owner):

More compactly generalizable alternative (pick whichever seems better to you):

Suggested change
lr=self.config.optimizer_kwargs["lr"],
betas=self.config.optimizer_kwargs["betas"],
**{k: self.config.optimizer_kwargs[k] for k in ("lr", "betas")}

super().__init__(config)
self.config = config
self.seq_len = config.n_ctx
self.dataset_seed = config.seed * 10 + 1
@JasonGross (Owner):

Why this seed? (Leave a comment about how (non)arbitrary this is)

@JasonGross (Owner) left a comment

Initial review comments

Comment on lines +221 to +223
if __name__ == "__main__":
print("Training model:", MAX_OF_10_CONFIG)
train_or_load_model(MAX_OF_10_CONFIG, force="train")
@JasonGross (Owner):

I don't think this should be in exp_max_of_n/train.py, maybe in exp_max_of_n/train_max_of_10.py? Or else we should take command line arguments (argparse? get chatgpt to quickly write up the argparse code?) for --max-of {2|10} [--force-train]
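
Something like the following could work for the CLI option (a rough sketch; MAX_OF_2_CONFIG and the flag names are hypothetical here, not code from this PR):

import argparse

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train or load a max-of-n model")
    parser.add_argument("--max-of", type=int, choices=(2, 10), default=10,
                        help="which max-of-n experiment config to use")
    parser.add_argument("--force-train", action="store_true",
                        help="retrain even if a saved model exists")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    # config = MAX_OF_2_CONFIG if args.max_of == 2 else MAX_OF_10_CONFIG
    # train_or_load_model(config, force="train" if args.force_train else None)
    print(args)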


def train_or_load_model(
config: Config,
force: Optional[Literal["train", "load"]] = None,
@JasonGross (Owner):

Does it make sense to use enums or global constants (TRAIN = "train") here rather than strings, to make it harder to typo?
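
A sketch of the enum option (names hypothetical, not from the PR); inheriting from str keeps members comparable to the existing "train"/"load" strings while making typos fail loudly:

from enum import Enum
from typing import Optional

class ForceMode(str, Enum):
    TRAIN = "train"
    LOAD = "load"

def train_or_load(force: Optional[ForceMode] = None) -> None:
    if force is ForceMode.TRAIN:
        print("training")
    elif force is ForceMode.LOAD:
        print("loading from cache")

train_or_load(ForceMode.TRAIN)  # ForceMode.TRIAN would raise AttributeError at once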

test_metrics: Sequence[Mapping[str, float]]


def _load_model(
@JasonGross (Owner):

Do we also want to add __all__ = [...] to control import behavior?

elif unit == "epochs":
trainer_args = {"max_epochs": n}
else:
raise ValueError
@JasonGross (Owner):

Suggested change
raise ValueError
raise ValueError(f"Invalid unit {unit}")

def train_or_load_model(
config: Config,
force: Optional[Literal["train", "load"]] = None,
save_to: Optional[Literal["disk", "disk_and_wandb"]] = "disk_and_wandb",
@JasonGross (Owner):

Do we want to make this a bitmask instead, a la re.match's flags argument?
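
A sketch of the bitmask idea (hypothetical names, not from this PR), using enum.Flag so save destinations combine with | in the style of re flags:

from enum import Flag, auto

class SaveTo(Flag):
    NONE = 0
    DISK = auto()
    WANDB = auto()

def save_model(save_to: SaveTo = SaveTo.DISK | SaveTo.WANDB) -> None:
    if save_to & SaveTo.DISK:
        print("saving to disk")
    if save_to & SaveTo.WANDB:
        print("saving to wandb")

save_model(SaveTo.DISK)                 # disk only
save_model(SaveTo.DISK | SaveTo.WANDB)  # both destinations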

return "cuda" if torch.cuda.is_available() and not deterministic else "cpu"


T = TypeVar("T")
@JasonGross (Owner):

?

from numpy.random import Generator
from torch import Tensor

PROJECT_ROOT = Path(__file__).parent.parent.parent
@JasonGross (Owner):

This seems fragile, especially since it goes outside the top-level package, and with respect to refactoring. Maybe we want to have a get_git_project_root? Or ensure it's a subdirectory of the package?
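
For the git-root idea, one possible get_git_project_root (a sketch, not code in this PR): ask git for the repository top level and fall back to a package-relative path when that fails.

import subprocess
from pathlib import Path

def get_git_project_root() -> Path:
    here = Path(__file__).resolve().parent
    try:
        out = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True, text=True, check=True, cwd=here,
        )
        return Path(out.stdout.strip())
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not inside a git checkout (e.g. an installed wheel): fall back.
        return here.parent.parent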

from typing import Any
from typing import Dict

# Implemented for https://github.com/lemon24/reader/issues/179
@JasonGross (Owner):

Did you write this code, @euanong? Do we need to stick some license notice at the top?

@JasonGross (Owner) commented:

I'm going to merge this now so we can build on it; code review can be fixed in another PR.

@JasonGross JasonGross merged commit 7e3e96c into main Jan 7, 2024
18 checks passed
@JasonGross JasonGross deleted the eo-6-jan branch January 7, 2024 13:03