diff --git a/entropix/sampler.py b/entropix/sampler.py index e42f3e8..dd71b60 100644 --- a/entropix/sampler.py +++ b/entropix/sampler.py @@ -32,6 +32,7 @@ def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float p_max = jnp.max(probs, axis=-1, keepdims=True) indices_to_remove = probs < (min_p * p_max) logit = jnp.where(indices_to_remove, jnp.full_like(logit, float('-inf')), logit) + probs = jax.nn.softmax(logit / temperature, axis=-1) # Apply top-k sampling top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)