diff --git a/entropix/main.py b/entropix/main.py index b8dcc42..9f37034 100644 --- a/entropix/main.py +++ b/entropix/main.py @@ -1,20 +1,24 @@ import math +from dataclasses import dataclass from pathlib import Path +from typing import Generator, Generic import jax import jax.numpy as jnp import tyro -from entropix.config import LLAMA_1B_PARAMS +from entropix.config import LLAMA_1B_PARAMS, ModelParams from entropix.kvcache import KVCache from entropix.model import xfmr -from entropix.sampler import SamplerConfig, sample from entropix.prompts import create_prompts_from_csv, prompt -from entropix.sampler import sample +from entropix.samplers import ST, Cfg_contra, EntropySampler +from entropix.samplers.baseline_sampler import SamplerConfig as BaselineSamplerConfig +from entropix.samplers.baseline_sampler import sample as baseline_sampler from entropix.tokenizer import Tokenizer -from entropix.weights import load_weights +from entropix.weights import XfmrWeights, load_weights + +DEFAULT_WEIGHTS_PATH = Path(__file__).parent / "../weights" -DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights' def apply_scaling(freqs: jax.Array): SCALE_FACTOR = 8 @@ -36,13 +40,15 @@ def scale_mid(_): wavelen < high_freq_wavelen, lambda _: freq, lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None), - None + None, ) return jax.vmap(scale_freq)(freqs) -def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array: +def precompute_freqs_cis( + dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32 +) -> jax.Array: freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) if use_scaled: freqs = apply_scaling(freqs) @@ -54,55 +60,82 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array: mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32) if seqlen > 1: - mask = jnp.full((seqlen, seqlen), float('-inf')) + mask = jnp.full((seqlen, seqlen), float("-inf")) mask = jnp.triu(mask, k=1) mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32) return mask -def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath('1B-Instruct')): - model_params = LLAMA_1B_PARAMS - xfmr_weights = load_weights(weights_path.absolute()) - tokenizer = Tokenizer('entropix/tokenizer.model') +# Create the batch of tokens +@dataclass(kw_only=True) +class TokenGenerator(Generic[Cfg_contra, ST]): + weights: XfmrWeights + model_params: ModelParams + tokenizer: Tokenizer + sampler: EntropySampler[Cfg_contra, ST] + sampler_cfg: Cfg_contra - # Create the batch of tokens - def generate(xfmr_weights, model_params, tokens): + def generate_from_prompt(self, init_tokens) -> Generator[str, None, None]: gen_tokens = None cur_pos = 0 - tokens = jnp.array([tokens], jnp.int32) + tokens = jnp.array([init_tokens], jnp.int32) bsz, seqlen = tokens.shape attn_mask = build_attn_mask(seqlen, cur_pos) - freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope) - kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim) - logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask) + mp = self.model_params + freqs_cis = precompute_freqs_cis(mp.head_dim, 
mp.max_seq_len, mp.rope_theta, mp.use_scaled_rope) + kvcache = KVCache.new(mp.n_layers, bsz, mp.max_seq_len, mp.n_local_kv_heads, mp.head_dim) + logits, kvcache, _, _ = xfmr(self.weights, mp, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask) next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32) gen_tokens = next_token - print(tokenizer.decode([next_token.item()]), end='', flush=True) + + yield self.tokenizer.decode([next_token.item()]) + cur_pos = seqlen stop = jnp.array([128001, 128008, 128009]) - sampler_cfg = SamplerConfig() + state: ST | None = None while cur_pos < 8192: cur_pos += 1 - logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache) - next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg) + logits, kvcache, scores, _ = xfmr( + self.weights, mp, next_token, cur_pos, freqs_cis[cur_pos : cur_pos + 1], kvcache + ) + next_token, state = self.sampler(gen_tokens, logits, scores, cfg=self.sampler_cfg, state=state) gen_tokens = jnp.concatenate((gen_tokens, next_token)) - print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True) + yield self.tokenizer.decode(next_token.tolist()[0]) if jnp.isin(next_token, stop).any(): break - csv_path = Path('entropix/data/prompts.csv') + +def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath("1B-Instruct")): + model_params = LLAMA_1B_PARAMS + xfmr_weights = load_weights(weights_path.absolute()) + # TODO(qdbp) make tokenizer into arg as well + tokenizer = Tokenizer("entropix/tokenizer.model") + + csv_path = Path("entropix/data/prompts.csv") prompts = create_prompts_from_csv(csv_path) PROMPT_TEST = False + # TODO(qdbp) make these configurable once more are implemented + sampler = baseline_sampler + sampler_cfg = BaselineSamplerConfig() + + generator = TokenGenerator( + weights=xfmr_weights, model_params=model_params, tokenizer=tokenizer, sampler=sampler, sampler_cfg=sampler_cfg + ) + if PROMPT_TEST: for p in prompts: print(p) - tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special='all') - generate(xfmr_weights, model_params, tokens) + tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special="all") + for token in generator.generate_from_prompt(tokens): + print(token, end="", flush=True) + else: print(prompt) - tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all') - generate(xfmr_weights, model_params, tokens) + tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special="all") + for token in generator.generate_from_prompt(tokens): + print(token, end="", flush=True) + -if __name__ == '__main__': +if __name__ == "__main__": tyro.cli(main) diff --git a/entropix/sampler.py b/entropix/sampler.py deleted file mode 100644 index e9b9420..0000000 --- a/entropix/sampler.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Dict, Tuple - -import chex -import jax -import jax.numpy as jnp - -LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E - -@jax.jit -def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Calculate the entropy and varentropy of the probability distribution using logsoftmax.""" - log_probs = jax.nn.log_softmax(logits, axis=axis) - probs = jnp.exp(log_probs) - entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2 # Convert to base-2 - varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis) - return entropy, varentropy - -def multinomial_sample_one(probs_sort: jax.Array, 
key) -> jax.Array: - """Samples one token from a multinomial distribution with sorted probabilities.""" - q = jax.random.exponential(key=key, shape=probs_sort.shape) - return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32) - -def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float | jax.Array, top_k: int | jax.Array, min_p: float | jax.Array, - key=jax.random.PRNGKey(1337),) -> jax.Array: - bsz = logits.shape[0] - logit = logits[:, -1] - probs = jax.nn.softmax(logit / temperature, axis=-1) - - # Apply min_p sampling - if min_p > 0.0: - p_max = jnp.max(probs, axis=-1, keepdims=True) - indices_to_remove = probs < (min_p * p_max) - logit = jnp.where(indices_to_remove, jnp.full_like(logit, float('-inf')), logit) - - # Apply top-k sampling - top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k) - probs_sort = jnp.flip(top_k_probs, axis=-1) - probs_idx = jnp.flip(top_k_indices, axis=-1) - probs_sum = jnp.cumsum(probs_sort, axis=-1) - # Apply top-p sampling - mask = jnp.where(probs_sum - probs_sort > top_p, 1.0, 0.0) - probs_sort = probs_sort * (1 - mask) - probs_sort = probs_sort / jnp.sum(probs_sort, axis=-1, keepdims=True) - next_token = multinomial_sample_one(probs_sort, key) - next_token_g = jnp.take_along_axis(probs_idx, next_token.reshape(bsz, 1), axis=-1) - return next_token_g.astype(jnp.int32) - -def calculate_metrics(logits: jnp.ndarray, attention_scores: jnp.ndarray) -> Dict[str, jnp.ndarray]: - entropy, varentropy = calculate_varentropy_logsoftmax(logits) - - attention_probs = jax.nn.softmax(attention_scores, axis=-1) - attn_entropy = -jnp.sum(attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1) - attn_varentropy = jnp.var(attn_entropy, axis=1) - - mean_attention = jnp.mean(attention_probs, axis=1) - agreement = jnp.mean(jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2)) - - interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3)) - - return { - "logits_entropy": jnp.mean(entropy), - "logits_varentropy": jnp.mean(varentropy), - "attn_entropy": jnp.mean(attn_entropy), - "attn_varentropy": jnp.mean(attn_varentropy), - "agreement": jnp.mean(agreement), - "interaction_strength": interaction_strength - } - -@chex.dataclass(kw_only=True, frozen=True) -class SamplerConfig: - """ - Encapsulation of all available sampler hyperparameters. - - This should be a good starting point for baselining experiments. 
- """ - - temp: float = 0.666 - top_p: float = 0.90 - top_k: int = 27 - min_p: float = 0.03 # Turn this down to 0.01 to reduce the shoggoth - - low_ent_thresh: float = 0.1 - low_vent_thresh: float = 0.1 - med_ent_thresh: float = 3.0 - high_ent_thresh: float = 5.0 - high_vent_thresh: float = 5.0 - - # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible - helv_attn_ent_offset: float = 1.3 - helv_attn_ent_coef: float = 0.2 - - lehv_interaction_strength_offset: float = 1.2 - lehv_interaction_strength_coef: float = 0.3 - - hehv_attn_ent_coef: float = 0.2 - hehv_attn_vent_offset: float = 2.0 - hehv_attn_vent_coef: float = 0.5 - - # TODO not convinced this should - n_adaptive_samples: int = 5 - - # Adaptive sampling parameters - ada_temp_logits: float = 0.3 - ada_temp_attn: float = 0.2 - ada_temp_agree: float = 0.2 - ada_top_p: float = 0.1 - ada_top_k_int: float = 0.3 - ada_top_k_agree: float = 0.2 - ada_min_p: float = 0.5 - ada_score_logits_ent: float = 0.1 - ada_score_attn_ent: float = 0.2 - ada_score_logits_vent: float = 0.3 - ada_score_attn_vent: float = 0.4 - ada_score_agree: float = 0.5 - ada_score_int: float = 0.6 - - -def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array, cfg: SamplerConfig, - clarifying_question_token: int = 2564, key=jax.random.PRNGKey(1337)) -> jax.Array: - - metrics = calculate_metrics(logits, attention_scores) - ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"] - attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"] - agreement = metrics["agreement"] - interaction_strength = metrics["interaction_strength"] - - # Low Entropy, Low Varentropy: "flowing with unspoken intent" - if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: - return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32) - - # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions" - elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh: - # Insert a clarifying question token if not already present - if not jnp.isin(gen_tokens[:,-1], clarifying_question_token).any(): - return jnp.array([[clarifying_question_token]]) - else: - # If we've just asked a question, sample with slightly higher temperature - temp_adj = cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy - return _sample(logits, temperature=min(1.5, cfg.temp * temp_adj), top_p=cfg.top_p, top_k=cfg.top_k, min_p=cfg.min_p, key=key) - - # Low Entropy, High Varentropy: "exploring forks in the path" - elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh: - temp_adj = cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength - top_k_adj = max(5, int(cfg.top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low - return _sample(logits, temperature=min(1.5, cfg.temp * temp_adj), top_p=cfg.top_p, top_k=top_k_adj, min_p=cfg.min_p, key=key) - - # High Entropy, High Varentropy: "resampling in the mist" - elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh: - # Use high temperature and adjusted top_p based on attention metrics - temp_adj = cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy - top_p_adj = max(0.5, cfg.top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high - return _sample(logits, temperature=max(2.0, cfg.temp 
* temp_adj), top_p=top_p_adj, top_k=cfg.top_k, min_p=cfg.min_p, key=key)
-
-  # Middle ground: use adaptive sampling
-  else:
-    logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"]
-    attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"]
-
-    temperature = cfg.temp * (1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * metrics["agreement"])
-    top_p = jnp.clip(cfg.top_p * (1 + cfg.ada_top_p * metrics["attn_varentropy"]), 0.1, 1.0)
-    top_k = int(jnp.clip(
-      jnp.round(cfg.top_k * (1 + cfg.ada_top_k_int * metrics["interaction_strength"].item() - cfg.ada_top_k_agree * metrics["agreement"].item())),
-      a_min=1,
-      a_max=100
-    ))
-    min_p = jnp.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
-
-    keys = jax.random.split(key, cfg.n_adaptive_samples)
-
-    samples = []
-    for sample_key in keys:
-      sample = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, key=sample_key)
-      samples.append(sample)
-
-    def score_sample(sample):
-      log_prob = jnp.sum(jax.nn.log_softmax(logits) * jax.nn.one_hot(sample, logits.shape[-1]))
-      confidence_score = (
-        (1 - metrics["logits_entropy"]) * cfg.ada_score_logits_ent +
-        (1 - metrics["attn_entropy"]) * cfg.ada_score_attn_ent +
-        (1 - metrics["logits_varentropy"]) * cfg.ada_score_logits_vent +
-        (1 - metrics["attn_varentropy"]) * cfg.ada_score_attn_vent +
-        metrics["agreement"] * cfg.ada_score_agree +
-        metrics["interaction_strength"] * cfg.ada_score_int
-      )
-      return log_prob + confidence_score
-
-    sample_scores = [score_sample(sample) for sample in samples]
-    best_sample_idx = jnp.argmax(jnp.array(sample_scores))
-    return samples[best_sample_idx]
diff --git a/entropix/samplers/__init__.py b/entropix/samplers/__init__.py
new file mode 100644
index 0000000..7cd567d
--- /dev/null
+++ b/entropix/samplers/__init__.py
@@ -0,0 +1,44 @@
+from typing import Protocol, TypeVar
+
+import jax
+
+# TODO(qdbp) these type vars would look MUCH less ugly if we just
+# bumped to 3.12 for the new non-fugly generics syntax and variance inference
+
+# sampler config typevar
+Cfg_contra = TypeVar("Cfg_contra", contravariant=True)  # input only -> contravariant
+
+# sampler state type variable
+ST = TypeVar("ST")  # i/o -> invariant
+
+
+class EntropySampler(Protocol[Cfg_contra, ST]):
+    """
+    A sampler is any object that can be called to perform a single sampling step (see EntropySampler.__call__).
+
+    Plain functions with a matching signature count too.
+    """
+
+    def __call__(
+        self,
+        gen_tokens: jax.Array,
+        logits: jax.Array,
+        attention_scores: jax.Array,
+        *,
+        cfg: Cfg_contra,
+        state: ST | None = None,
+        key: jax.Array = jax.random.PRNGKey(1337),
+    ) -> tuple[jax.Array, ST]:
+        """
+        Performs a single sampling step.
+
+        Args:
+            gen_tokens: Array of the current token context.
+            logits: Array of next-token logits predicted by the model.
+            attention_scores: Array of attention scores as returned by xfmr.
+            cfg: sampler-specific configuration object encapsulating all other sampling parameters.
+
+        Returns:
+            A (next token, updated sampler state) tuple; the state is threaded back in on the next call.
+        """
+        ...
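For illustration only (not part of the patch): because EntropySampler is a Protocol, a plain function with the matching keyword-only signature satisfies it. Below is a minimal sketch of a stateless sampler; the names GreedyCfg and greedy_sample are hypothetical, and such a function could be handed to TokenGenerator as sampler=greedy_sample, sampler_cfg=GreedyCfg().

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from entropix.samplers import EntropySampler


@dataclass(kw_only=True)
class GreedyCfg:
    pass  # greedy decoding needs no hyperparameters


def greedy_sample(
    gen_tokens: jax.Array,
    logits: jax.Array,
    attention_scores: jax.Array,
    *,
    cfg: GreedyCfg,
    state: None = None,
    key: jax.Array = jax.random.PRNGKey(1337),
) -> tuple[jax.Array, None]:
    # Always take the highest-probability next token. The sampler is stateless,
    # so the second tuple element (the carried state) is always None.
    return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32), None


# Structural typing: the bare function checks against the protocol.
_conforms: EntropySampler[GreedyCfg, None] = greedy_sample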
diff --git a/entropix/samplers/baseline_sampler.py b/entropix/samplers/baseline_sampler.py new file mode 100644 index 0000000..a61f5f0 --- /dev/null +++ b/entropix/samplers/baseline_sampler.py @@ -0,0 +1,276 @@ +from typing import Dict, Tuple + +import chex +import jax +import jax.numpy as jnp + +LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E + + +@jax.jit +def calculate_varentropy_logsoftmax( + logits: jnp.ndarray, axis: int = -1 +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Calculate the entropy and varentropy of the probability distribution using logsoftmax.""" + log_probs = jax.nn.log_softmax(logits, axis=axis) + probs = jnp.exp(log_probs) + entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2 # Convert to base-2 + varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None]) ** 2, axis=axis) + return entropy, varentropy + + +def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array: + """Samples one token from a multinomial distribution with sorted probabilities.""" + q = jax.random.exponential(key=key, shape=probs_sort.shape) + return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32) + + +def _sample( + logits: jax.Array, + *, + temperature: float | jax.Array, + top_p: float | jax.Array, + top_k: int | jax.Array, + min_p: float | jax.Array, + key=jax.random.PRNGKey(1337), +) -> jax.Array: + bsz = logits.shape[0] + logit = logits[:, -1] + probs = jax.nn.softmax(logit / temperature, axis=-1) + + # Apply min_p sampling + if min_p > 0.0: + p_max = jnp.max(probs, axis=-1, keepdims=True) + indices_to_remove = probs < (min_p * p_max) + logit = jnp.where(indices_to_remove, jnp.full_like(logit, float("-inf")), logit) + + # Apply top-k sampling + top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k) + probs_sort = jnp.flip(top_k_probs, axis=-1) + probs_idx = jnp.flip(top_k_indices, axis=-1) + probs_sum = jnp.cumsum(probs_sort, axis=-1) + # Apply top-p sampling + mask = jnp.where(probs_sum - probs_sort > top_p, 1.0, 0.0) + probs_sort = probs_sort * (1 - mask) + probs_sort = probs_sort / jnp.sum(probs_sort, axis=-1, keepdims=True) + next_token = multinomial_sample_one(probs_sort, key) + next_token_g = jnp.take_along_axis(probs_idx, next_token.reshape(bsz, 1), axis=-1) + return next_token_g.astype(jnp.int32) + + +def calculate_metrics( + logits: jnp.ndarray, attention_scores: jnp.ndarray +) -> Dict[str, jnp.ndarray]: + entropy, varentropy = calculate_varentropy_logsoftmax(logits) + + attention_probs = jax.nn.softmax(attention_scores, axis=-1) + attn_entropy = -jnp.sum( + attention_probs * jnp.log2(jnp.clip(attention_probs, 1e-10, 1.0)), axis=-1 + ) + attn_varentropy = jnp.var(attn_entropy, axis=1) + + mean_attention = jnp.mean(attention_probs, axis=1) + agreement = jnp.mean( + jnp.abs(attention_probs - mean_attention[:, None, :]), axis=(1, 2) + ) + + interaction_strength = jnp.mean(jnp.abs(attention_scores), axis=(1, 2, 3)) + + return { + "logits_entropy": jnp.mean(entropy), + "logits_varentropy": jnp.mean(varentropy), + "attn_entropy": jnp.mean(attn_entropy), + "attn_varentropy": jnp.mean(attn_varentropy), + "agreement": jnp.mean(agreement), + "interaction_strength": interaction_strength, + } + + +@chex.dataclass(kw_only=True, frozen=True) +class SamplerConfig: + """ + Encapsulation of all available sampler hyperparameters. + + This should be a good starting point for baselining experiments. 
+ """ + + temp: float = 0.666 + top_p: float = 0.90 + top_k: int = 27 + min_p: float = 0.03 # Turn this down to 0.01 to reduce the shoggoth + + low_ent_thresh: float = 0.1 + low_vent_thresh: float = 0.1 + med_ent_thresh: float = 3.0 + high_ent_thresh: float = 5.0 + high_vent_thresh: float = 5.0 + + # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible + helv_attn_ent_offset: float = 1.3 + helv_attn_ent_coef: float = 0.2 + + lehv_interaction_strength_offset: float = 1.2 + lehv_interaction_strength_coef: float = 0.3 + + hehv_attn_ent_coef: float = 0.2 + hehv_attn_vent_offset: float = 2.0 + hehv_attn_vent_coef: float = 0.5 + + # TODO not convinced this should + n_adaptive_samples: int = 5 + + # Adaptive sampling parameters + ada_temp_logits: float = 0.3 + ada_temp_attn: float = 0.2 + ada_temp_agree: float = 0.2 + ada_top_p: float = 0.1 + ada_top_k_int: float = 0.3 + ada_top_k_agree: float = 0.2 + ada_min_p: float = 0.5 + ada_score_logits_ent: float = 0.1 + ada_score_attn_ent: float = 0.2 + ada_score_logits_vent: float = 0.3 + ada_score_attn_vent: float = 0.4 + ada_score_agree: float = 0.5 + ada_score_int: float = 0.6 + + +# implements EntropySampler[SamplerConfig, None] +def sample( + gen_tokens: jax.Array, + logits: jax.Array, + attention_scores: jax.Array, + *, + cfg: SamplerConfig, + state: None = None, + clarifying_question_token: int = 2564, + key=jax.random.PRNGKey(1337), + # the None in the tuple satisfies the EntropySampler protocol. + # the baseline sampler is currently stateless but this can change later +) -> tuple[jax.Array, None]: + metrics = calculate_metrics(logits, attention_scores) + ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"] + attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"] + agreement = metrics["agreement"] + interaction_strength = metrics["interaction_strength"] + + # Low Entropy, Low Varentropy: "flowing with unspoken intent" + if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: + return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32), None + + # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions" + elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh: + # Insert a clarifying question token if not already present + if not jnp.isin(gen_tokens[:, -1], clarifying_question_token).any(): + return jnp.array([[clarifying_question_token]]), None + else: + # If we've just asked a question, sample with slightly higher temperature + temp_adj = ( + cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent + ) # Increase temperature based on attention entropy + return _sample( + logits, + temperature=min(1.5, cfg.temp * temp_adj), + top_p=cfg.top_p, + top_k=cfg.top_k, + min_p=cfg.min_p, + key=key, + ), None + + # Low Entropy, High Varentropy: "exploring forks in the path" + elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh: + temp_adj = ( + cfg.lehv_interaction_strength_offset + + cfg.lehv_interaction_strength_coef * interaction_strength + ) # Increase temperature based on interaction strength + top_k_adj = max( + 5, int(cfg.top_k * (1 + 0.5 * (1 - agreement))) + ) # Increase top_k when agreement is low + return _sample( + logits, + temperature=min(1.5, cfg.temp * temp_adj), + top_p=cfg.top_p, + top_k=top_k_adj, + min_p=cfg.min_p, + key=key, + ), None + + # High Entropy, High Varentropy: "resampling in the mist" + elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh: + # Use high temperature and adjusted top_p based 
on attention metrics + temp_adj = ( + cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent + ) # Increase temperature based on attention varentropy + top_p_adj = max( + 0.5, cfg.top_p - cfg.hehv_attn_ent_coef * attn_ent + ) # Decrease top_p when attention entropy is high + return _sample( + logits, + temperature=max(2.0, cfg.temp * temp_adj), + top_p=top_p_adj, + top_k=cfg.top_k, + min_p=cfg.min_p, + key=key, + ), None + + # Middle ground: use adaptive sampling + else: + logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"] + attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"] + + temperature = cfg.temp * ( + 1 + + cfg.ada_temp_logits * logits_uncertainty + + cfg.ada_temp_attn * attn_uncertainty + - cfg.ada_temp_agree * metrics["agreement"] + ) + top_p = jnp.clip( + cfg.top_p * (1 + cfg.ada_top_p * metrics["attn_varentropy"]), 0.1, 1.0 + ) + top_k = int( + jnp.clip( + jnp.round( + cfg.top_k + * ( + 1 + + cfg.ada_top_k_int * metrics["interaction_strength"].item() + - cfg.ada_top_k_agree * metrics["agreement"].item() + ) + ), + a_min=1, + a_max=100, + ) + ) + min_p = jnp.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5) + + keys = jax.random.split(key, cfg.n_adaptive_samples) + + samples = [] + for sample_key in keys: + sample = _sample( + logits, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + key=sample_key, + ) + samples.append(sample) + + def score_sample(sample): + log_prob = jnp.sum( + jax.nn.log_softmax(logits) * jax.nn.one_hot(sample, logits.shape[-1]) + ) + confidence_score = ( + (1 - metrics["logits_entropy"]) * cfg.ada_score_logits_ent + + (1 - metrics["attn_entropy"]) * cfg.ada_score_attn_ent + + (1 - metrics["logits_varentropy"]) * cfg.ada_score_logits_vent + + (1 - metrics["attn_varentropy"]) * cfg.ada_score_attn_vent + + metrics["agreement"] * cfg.ada_score_agree + + metrics["interaction_strength"] * cfg.ada_score_int + ) + return log_prob + confidence_score + + sample_scores = [score_sample(sample) for sample in samples] + best_sample_idx = jnp.argmax(jnp.array(sample_scores)) + return samples[best_sample_idx], None
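A quick sanity check (illustrative, not part of the patch) of the entropy/varentropy metrics that drive the branching above, assuming a toy (bsz=1, seqlen=1, vocab=4) logits array. Entropy is reported in bits (hence the division by LN_2), and varentropy is the variance of the per-token surprisal: a uniform distribution gives 2.0 bits of entropy and zero varentropy, while a peaked distribution gives low entropy but nonzero varentropy.

import jax.numpy as jnp

from entropix.samplers.baseline_sampler import calculate_varentropy_logsoftmax

# Uniform over 4 tokens: every token is equally surprising,
# so entropy == 2.0 bits and varentropy == 0.0.
uniform_logits = jnp.zeros((1, 1, 4))
ent, vent = calculate_varentropy_logsoftmax(uniform_logits)
print(ent, vent)  # ~2.0, ~0.0

# Peaked distribution: low entropy, but the rare tokens are far more
# surprising than the likely one, so varentropy is well above zero.
peaked_logits = jnp.array([[[4.0, 0.0, 0.0, 0.0]]])
print(calculate_varentropy_logsoftmax(peaked_logits))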