Enable custom masks as optional input for models for batch processing #1044

Open
nath1295 opened this issue Oct 13, 2024 · 0 comments

@nath1295

From a quick read of the code, it seems the attention mask tensor is created on the fly during inference and then broadcast across all prompt sequences in the batch within each layer (normally the batch size is one, but batched inference should not rely on that assumption). This is a problem for batch inference: when prompts have different lengths, there is no way to mask out the padding tokens. The only workaround right now is to attend to the pad tokens as well, which changes the output. I just want to confirm that my understanding here is correct.
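
For context, here is a minimal sketch (not mlx-lm's actual implementation, just an illustration of the pattern) of a causal mask built on the fly from the sequence length alone. Because it depends only on the length, the same mask applies to every prompt in the batch and cannot account for per-prompt padding:

import mlx.core as mx

# Illustrative helper (hypothetical, not from mlx-lm): a causal mask built
# purely from the sequence length is identical for every prompt in the batch.
def naive_causal_mask(seq_len):
    idx = mx.arange(seq_len)
    visible = (idx[None, :] <= idx[:, None]).astype(mx.float32)  # (L, L)
    # Additive form: 0 where a position may attend, -1e9 where it may not
    return (1 - visible) * -1e9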

It would be beneficial to be able to supply these mask tensors ourselves, so that the prompt cache can simply hold zeros at the padded positions without affecting the output. I am not sure whether this is possible.
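
As a sketch of what such a custom mask could look like (the (batch, 1, seq_len, seq_len) shape and the additive -1e9 convention are my assumptions here, not the library's API), the per-prompt padding information could be folded into the causal mask like this:

# Hypothetical sketch: combine a per-prompt padding mask (1 = real token,
# 0 = left padding) with the shared causal mask into one additive mask per
# prompt. The (B, 1, L, L) shape assumes broadcasting over attention heads.
def build_batch_mask(padding_mask):
    L = padding_mask.shape[1]
    idx = mx.arange(L)
    causal = (idx[None, :] <= idx[:, None]).astype(mx.float32)   # (L, L)
    keys_visible = padding_mask[:, None, :].astype(mx.float32)   # (B, 1, L)
    combined = causal[None, :, :] * keys_visible                 # (B, L, L)
    return ((1 - combined) * -1e9)[:, None, :, :]                # (B, 1, L, L)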

A simple example use case:

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache
import mlx.core as mx

model, tokenizer = load('my/model/path')
pad_token_id = tokenizer.bos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id

# Get the lists of token ids for all the prompts
prompts = [
    'The weather is nice out there',
    'The weather is awful out there, and this is a longer prompt'
]
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]

# Build a padding mask for each prompt: 1 for real tokens, 0 for left padding
prompt_lens = [len(pt) for pt in prompt_tokens]
max_prompt_len = max(prompt_lens)
mask = [[0] * (max_prompt_len - n) + [1] * n for n in prompt_lens]
mask = mx.array(mask).astype(mx.int16)

# Pad the shorter prompts
prompt_tokens = [[pad_token_id] * (max_prompt_len - n) + tks for tks, n in zip(prompt_tokens, prompt_lens)]
prompt_tokens = mx.array(prompt_tokens)

# Make the cache
cache = make_prompt_cache(model)

# Get the logits for the next token of each prompt; the `mask` keyword
# argument here is the feature being requested
logits = model(prompt_tokens, mask=mask, cache=cache)
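
Assuming the call above returns logits of shape (batch, seq_len, vocab_size), the next token for each prompt could then be picked greedily from the last position:

# Greedy selection of the next token for each prompt in the batch
next_tokens = mx.argmax(logits[:, -1, :], axis=-1)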

I realise this might take a lot of rework, but I am just wondering whether it is possible.
