
[RFC] boring stuff: define a sampler interface #31

Open
qdbp opened this issue Oct 7, 2024 · 2 comments · May be fixed by #40

Comments

@qdbp
Contributor

qdbp commented Oct 7, 2024

tl;dr samplers should be swappable and composable. for that we need a common interface

There's a lot of hot stuff in the pipeline re. MCTS, the vanilla sampler, etc.

One thing I'm afraid of is that there's going to be a lot of spaghetti involving bespoke/subtly different ways to call different samplers, which will make benchmarking and comparison painful.

I want to get ahead of this issue by defining a common interface to samplers. Since this is Python I think this should be a Protocol, something like:

import jax
from typing import Protocol

class Sampler[Cfg, ST](Protocol):  # PEP 695 generics, Python 3.12+
    def __call__(self, gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array,
                 cfg: Cfg, state: ST) -> tuple[jax.Array, ST]:
        ...

which is a light touch (no need to inherit) but can still be checked by a static type checker. This should be a generic enough framework for people to implement their favorite MuZero etc. and have it all plug into the same harnesses.
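A minimal sketch of what "no need to inherit" looks like in practice, assuming logits arrive as a `(batch, vocab)` array for the last position. `GreedyCfg` and `GreedySampler` are hypothetical names, and the `TypeVar` spelling is used only so the sketch also runs on Python versions before 3.12. Note that `@runtime_checkable` only verifies that a `__call__` attribute exists; checking the full signature is left to a static checker such as mypy or pyright.

```python
from dataclasses import dataclass
from typing import Protocol, TypeVar, runtime_checkable

import jax
import jax.numpy as jnp

Cfg = TypeVar("Cfg", contravariant=True)  # only consumed, so contravariant
ST = TypeVar("ST")  # consumed and returned, so invariant


@runtime_checkable
class Sampler(Protocol[Cfg, ST]):
    def __call__(self, gen_tokens: jax.Array, logits: jax.Array,
                 attention_scores: jax.Array, cfg: Cfg,
                 state: ST) -> tuple[jax.Array, ST]:
        ...


@dataclass(frozen=True)
class GreedyCfg:
    """Hypothetical config: greedy decoding has no tunable knobs."""


class GreedySampler:
    # No inheritance from Sampler: it conforms purely by its __call__ shape.
    def __call__(self, gen_tokens: jax.Array, logits: jax.Array,
                 attention_scores: jax.Array, cfg: GreedyCfg,
                 state: None) -> tuple[jax.Array, None]:
        # Assumes logits of shape (batch, vocab); returns token ids of shape (batch, 1).
        return jnp.argmax(logits, axis=-1)[:, None], state


# A static checker verifies this assignment against the protocol.
sampler: Sampler[GreedyCfg, None] = GreedySampler()
```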

My goal here is to have an easy-to-maintain sampler benchmarking suite with plug-and-play samplers.
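To illustrate the plug-and-play goal, here is a sketch of a tiny harness that decodes with whichever sampler it is handed. `toy_model`, `run_decode`, `greedy` and `lowest` are all hypothetical; `toy_model` stands in for a real forward pass and always favours `last_token + 1`. The samplers are plain functions, which also satisfy a `__call__`-based protocol.

```python
import jax
import jax.numpy as jnp


def greedy(gen_tokens, logits, attention_scores, cfg, state):
    # Highest-scoring token per batch row -> shape (batch, 1).
    return jnp.argmax(logits, axis=-1)[:, None], state


def lowest(gen_tokens, logits, attention_scores, cfg, state):
    # Deliberately silly baseline: lowest-scoring token per batch row.
    return jnp.argmin(logits, axis=-1)[:, None], state


def toy_model(tokens: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical forward pass: returns (logits, attention_scores) for the last position."""
    vocab = 5
    nxt = (tokens[:, -1] + 1) % vocab
    logits = jax.nn.one_hot(nxt, vocab) * 10.0  # strongly prefer last_token + 1
    attn = jnp.zeros((tokens.shape[0], tokens.shape[1]))
    return logits, attn


def run_decode(sampler, cfg, state, prompt: jax.Array, n_steps: int):
    """Decode n_steps times with any sampler that follows the common signature."""
    tokens = prompt
    for _ in range(n_steps):
        logits, attn = toy_model(tokens)
        new, state = sampler(tokens, logits, attn, cfg, state)
        tokens = jnp.concatenate([tokens, new.astype(tokens.dtype)], axis=1)
    return tokens, state
```

Swapping samplers is then just a matter of passing a different callable: with the prompt `jnp.array([[0]], dtype=jnp.int32)` and three steps, `greedy` yields `[[0, 1, 2, 3]]` while `lowest` yields `[[0, 0, 0, 0]]`.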

EDIT

given the JAX idiom of passing state in as an argument and returning the updated state (and to support some sampler work of my own, tee hee), I think it makes sense to expand this interface to include an ST type var.
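Under that idiom, a stateful sampler takes its state in and hands the updated state back, and the caller threads it through the decode loop. A minimal sketch, assuming the state (`ST`) is nothing more than a JAX PRNG key; `TempCfg` and `temperature_sampler` are hypothetical names.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class TempCfg:
    temperature: float = 1.0  # hypothetical config field


def temperature_sampler(gen_tokens: jax.Array, logits: jax.Array,
                        attention_scores: jax.Array, cfg: TempCfg,
                        state: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Sketch of a stateful sampler whose state (ST) is a PRNG key."""
    # Split the incoming key: one half becomes the next state, the other is consumed now.
    next_key, sub = jax.random.split(state)
    tok = jax.random.categorical(sub, logits / cfg.temperature, axis=-1)
    return tok[:, None], next_key


# The caller threads the returned state through the loop instead of reusing one key.
key = jax.random.PRNGKey(0)
logits = jnp.zeros((1, 8))  # uniform toy logits: batch of 1, vocab of 8
sampled = []
for _ in range(3):
    tok, key = temperature_sampler(None, logits, None, TempCfg(temperature=0.7), key)
    sampled.append(int(tok[0, 0]))
```

Because the sampler is a pure function of its inputs, the same starting key always reproduces the same tokens, which is exactly what a benchmarking harness wants.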

EDIT 2

since we're returning a jax.Array as the token output, I propose that a sample call be allowed to return either a single token or an entire sequence of tokens at once. callers should be able to handle either case (and, really, a single token is just a sequence of length 1).
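On the caller side, supporting both cases can be as simple as normalising every output to a 2-D block before appending it. A sketch, assuming single-token samplers return shape `(batch,)` and multi-token samplers return `(batch, n_new)`; `as_token_block` and `append_tokens` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def as_token_block(sampled: jax.Array) -> jax.Array:
    """Normalise sampler output to shape (batch, n_new).

    (batch,) is treated as one token per row; (batch, n_new) is passed through,
    so a (batch, 1) output is simply the length-1 sequence case.
    """
    if sampled.ndim == 1:
        return sampled[:, None]
    if sampled.ndim == 2:
        return sampled
    raise ValueError(f"expected a 1-D or 2-D token array, got shape {sampled.shape}")


def append_tokens(gen_tokens: jax.Array, sampled: jax.Array) -> jax.Array:
    # Works identically for one new token or a whole block of new tokens.
    block = as_token_block(sampled).astype(gen_tokens.dtype)
    return jnp.concatenate([gen_tokens, block], axis=1)
```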

qdbp changed the title from "boring stuff: define a sampler interface" to "[RFC] boring stuff: define a sampler interface" on Oct 7, 2024
@aw632

aw632 commented Oct 7, 2024

I am working on adding Entropix to vLLM so I can plug it into my inference workflows. Will reference this issue here

qdbp added a commit to qdbp/entropix that referenced this issue Oct 7, 2024
qdbp added a commit to qdbp/entropix that referenced this issue Oct 8, 2024
qdbp added a commit to qdbp/entropix that referenced this issue Oct 8, 2024
@qdbp
Contributor Author

qdbp commented Oct 8, 2024

we should fold the PRNG "key" into either the config or the state. having it as a top-level arg is messy.

if it should be common to all models, dataclass inheritance will come in handy here.
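One way the dataclass-inheritance idea could look: a hypothetical `BaseSamplerState` owns the PRNG key plus a helper for splitting it, and each sampler's state subclasses it to add its own fields (`CountingState.steps` is purely illustrative). This is a sketch, not the repo's design. Plain frozen dataclasses are not JAX pytrees, so passing them through `jax.jit` would additionally require registering them (e.g. via `jax.tree_util.register_pytree_node`).

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class BaseSamplerState:
    key: jax.Array  # PRNG key, common to every sampler's state

    def split_key(self):
        """Return (copy of this state with a fresh key, subkey to consume now)."""
        new_key, sub = jax.random.split(self.key)
        # replace() keeps the subclass type and any subclass fields intact.
        return replace(self, key=new_key), sub


@dataclass(frozen=True)
class CountingState(BaseSamplerState):
    steps: int = 0  # hypothetical sampler-specific field


def counting_sampler(gen_tokens, logits, attention_scores, cfg, state: CountingState):
    # No top-level key argument: randomness comes from the state itself.
    state, sub = state.split_key()
    tok = jax.random.categorical(sub, logits, axis=-1)[:, None]
    return tok, replace(state, steps=state.steps + 1)


state = CountingState(key=jax.random.PRNGKey(0))
tok, state = counting_sampler(None, jnp.zeros((2, 4)), None, None, state)
```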
