(closes xjdr-alt#31) adds a base sampler protocol, reorganizes samplers and main
qdbp committed Oct 8, 2024
1 parent 2f7b4ad commit c601285
Showing 4 changed files with 382 additions and 218 deletions.
91 changes: 62 additions & 29 deletions entropix/main.py
@@ -1,20 +1,24 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Generic

import jax
import jax.numpy as jnp
import tyro

from entropix.config import LLAMA_1B_PARAMS
from entropix.config import LLAMA_1B_PARAMS, ModelParams
from entropix.kvcache import KVCache
from entropix.model import xfmr
from entropix.sampler import SamplerConfig, sample
from entropix.prompts import create_prompts_from_csv, prompt
from entropix.sampler import sample
from entropix.samplers import ST, Cfg_contra, EntropySampler
from entropix.samplers.baseline_sampler import SamplerConfig as BaselineSamplerConfig
from entropix.samplers.baseline_sampler import sample as baseline_sampler
from entropix.tokenizer import Tokenizer
from entropix.weights import load_weights
from entropix.weights import XfmrWeights, load_weights

DEFAULT_WEIGHTS_PATH = Path(__file__).parent / "../weights"

DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights'

def apply_scaling(freqs: jax.Array):
SCALE_FACTOR = 8
@@ -36,13 +40,15 @@ def scale_mid(_):
wavelen < high_freq_wavelen,
lambda _: freq,
lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None),
None
None,
)

return jax.vmap(scale_freq)(freqs)


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array:
def precompute_freqs_cis(
dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32
) -> jax.Array:
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
if use_scaled:
freqs = apply_scaling(freqs)
@@ -54,55 +60,82 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled
def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
if seqlen > 1:
mask = jnp.full((seqlen, seqlen), float('-inf'))
mask = jnp.full((seqlen, seqlen), float("-inf"))
mask = jnp.triu(mask, k=1)
mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
return mask


def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath('1B-Instruct')):
model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(weights_path.absolute())
tokenizer = Tokenizer('entropix/tokenizer.model')
# Create the batch of tokens
@dataclass(kw_only=True)
class TokenGenerator(Generic[Cfg_contra, ST]):
weights: XfmrWeights
model_params: ModelParams
tokenizer: Tokenizer
sampler: EntropySampler[Cfg_contra, ST]
sampler_cfg: Cfg_contra

# Create the batch of tokens
def generate(xfmr_weights, model_params, tokens):
def generate_from_prompt(self, init_tokens) -> Generator[str, None, None]:
gen_tokens = None
cur_pos = 0
tokens = jnp.array([tokens], jnp.int32)
tokens = jnp.array([init_tokens], jnp.int32)
bsz, seqlen = tokens.shape
attn_mask = build_attn_mask(seqlen, cur_pos)
freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
mp = self.model_params
freqs_cis = precompute_freqs_cis(mp.head_dim, mp.max_seq_len, mp.rope_theta, mp.use_scaled_rope)
kvcache = KVCache.new(mp.n_layers, bsz, mp.max_seq_len, mp.n_local_kv_heads, mp.head_dim)
logits, kvcache, _, _ = xfmr(self.weights, mp, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
gen_tokens = next_token
print(tokenizer.decode([next_token.item()]), end='', flush=True)

yield self.tokenizer.decode([next_token.item()])

cur_pos = seqlen
stop = jnp.array([128001, 128008, 128009])
sampler_cfg = SamplerConfig()
state: ST | None = None
while cur_pos < 8192:
cur_pos += 1
logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg)
logits, kvcache, scores, _ = xfmr(
self.weights, mp, next_token, cur_pos, freqs_cis[cur_pos : cur_pos + 1], kvcache
)
next_token, state = self.sampler(gen_tokens, logits, scores, cfg=self.sampler_cfg, state=state)
gen_tokens = jnp.concatenate((gen_tokens, next_token))
print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
yield self.tokenizer.decode(next_token.tolist()[0])
if jnp.isin(next_token, stop).any():
break

csv_path = Path('entropix/data/prompts.csv')

def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath("1B-Instruct")):
model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(weights_path.absolute())
# TODO(qdbp) make tokenizer into arg as well
tokenizer = Tokenizer("entropix/tokenizer.model")

csv_path = Path("entropix/data/prompts.csv")
prompts = create_prompts_from_csv(csv_path)
PROMPT_TEST = False

# TODO(qdbp) make these configurable once more are implemented
sampler = baseline_sampler
sampler_cfg = BaselineSamplerConfig()

generator = TokenGenerator(
weights=xfmr_weights, model_params=model_params, tokenizer=tokenizer, sampler=sampler, sampler_cfg=sampler_cfg
)

if PROMPT_TEST:
for p in prompts:
print(p)
tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special='all')
generate(xfmr_weights, model_params, tokens)
tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special="all")
for token in generator.generate_from_prompt(tokens):
print(token, end="", flush=True)

else:
print(prompt)
tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
generate(xfmr_weights, model_params, tokens)
tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special="all")
for token in generator.generate_from_prompt(tokens):
print(token, end="", flush=True)


if __name__ == '__main__':
if __name__ == "__main__":
tyro.cli(main)
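
For reference, here is a minimal standalone sketch of driving the new TokenGenerator outside of main(). This is hypothetical example code, not part of the commit; the weights path and prompt text are placeholders, and it assumes the same baseline sampler defaults that main() wires up.

from pathlib import Path

from entropix.config import LLAMA_1B_PARAMS
from entropix.main import TokenGenerator
from entropix.samplers.baseline_sampler import SamplerConfig, sample
from entropix.tokenizer import Tokenizer
from entropix.weights import load_weights

# Assemble the generator from explicit parts, mirroring what main() does.
weights = load_weights(Path("weights/1B-Instruct").absolute())
tokenizer = Tokenizer("entropix/tokenizer.model")
generator = TokenGenerator(
    weights=weights,
    model_params=LLAMA_1B_PARAMS,
    tokenizer=tokenizer,
    sampler=sample,  # any callable satisfying the EntropySampler protocol
    sampler_cfg=SamplerConfig(),
)

# Stream decoded text pieces as tokens are sampled.
tokens = tokenizer.encode("Why is the sky blue?", bos=False, eos=False, allowed_special="all")
for piece in generator.generate_from_prompt(tokens):
    print(piece, end="", flush=True)

With this split, swapping samplers only means passing a different sampler / sampler_cfg pair; the generation loop itself is unchanged.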
189 changes: 0 additions & 189 deletions entropix/sampler.py

This file was deleted.

44 changes: 44 additions & 0 deletions entropix/samplers/__init__.py
@@ -0,0 +1,44 @@
from typing import Protocol, TypeVar

import jax

# TODO(qdbp) these type vars would look MUCH less ugly if we just
# bumped to 3.12 for the new non-fugly generics syntax and variance inference

# sampler config typevar
Cfg_contra = TypeVar("Cfg_contra", contravariant=True) # input only -> contravariant

# sampler state type variable
ST = TypeVar("ST") # i/o -> invariant


class EntropySampler(Protocol[Cfg_contra, ST]):
    """
    A sampler is any object that can be called to perform a single sampling step
    (see EntropySampler.__call__). Plain functions with a matching signature also count.
    """

    def __call__(
        self,
        gen_tokens: jax.Array,
        logits: jax.Array,
        attention_scores: jax.Array,
        *,
        cfg: Cfg_contra,
        state: ST | None = None,
        key: jax.Array = jax.random.PRNGKey(1337),
    ) -> tuple[jax.Array, ST]:
        """
        Performs a single sampling step.

        Args:
            gen_tokens: Array of the current token context.
            logits: Array of next-token logits predicted by the model.
            attention_scores: Array of attention scores as returned by xfmr.
            cfg: sampler-specific configuration object encapsulating any other sampling parameters.
            state: optional sampler-specific state carried over from the previous step.
            key: PRNG key used for any stochastic sampling decisions.

        Returns:
            A tuple of (next token as a jax.Array, updated sampler state).
        """
        ...
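
To make the protocol concrete, here is a minimal sketch of a conforming sampler (hypothetical, not part of this commit): a stateless greedy argmax sampler with an empty config. Plain functions satisfy the protocol structurally, so no class wrapper is needed.

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from entropix.samplers import EntropySampler


@dataclass
class GreedyConfig:
    # No knobs for plain argmax; a real config would carry temperature, top-p, etc.
    pass


def greedy_sample(
    gen_tokens: jax.Array,
    logits: jax.Array,
    attention_scores: jax.Array,
    *,
    cfg: GreedyConfig,
    state: None = None,
    key: jax.Array = jax.random.PRNGKey(1337),
) -> tuple[jax.Array, None]:
    # Pick the highest-logit token at the last position; shape (bsz, 1).
    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    return next_token, None


# Type-checks as an EntropySampler[GreedyConfig, None] and can be passed to TokenGenerator.
greedy: EntropySampler[GreedyConfig, None] = greedy_sample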
(Diff for the fourth changed file was not loaded.)
