-
Notifications
You must be signed in to change notification settings - Fork 411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] PTQ for generate_v2
#1866
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from torchtune.generation import sample | ||
|
||
from torchtune.modules.transforms import Transform | ||
from torchtune.training import compile_model | ||
|
||
|
||
class SingleTurnYAMLToMessages(Transform): | ||
|
@@ -65,29 +66,37 @@ class InferenceRecipe: | |
|
||
This *does not* currently support the following features: | ||
- torch.compile | ||
- quantization through torchao | ||
- multi-GPU generation | ||
- batch generation | ||
""" | ||
|
||
def __init__(self, cfg: DictConfig) -> None: | ||
self._device = utils.get_device(device=cfg.device) | ||
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) | ||
self._logger = utils.get_logger(cfg.log_level) | ||
self.device = utils.get_device(device=cfg.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a public recipe, no need to be a "private" variable. cc @pbontrager |
||
self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device) | ||
self.logger = utils.get_logger(cfg.log_level) | ||
training.set_seed(seed=cfg.seed) | ||
|
||
def setup(self, cfg: DictConfig) -> None: | ||
"""Setup the model and transforms.""" | ||
# Load checkpointer and state_dict | ||
# Load checkpointer | ||
_checkpointer = config.instantiate(cfg.checkpointer) | ||
_ckpt_dict = _checkpointer.load_checkpoint() | ||
|
||
# Instantiate model | ||
with training.set_default_dtype(self._dtype), self._device: | ||
with training.set_default_dtype(self.dtype), self.device: | ||
model = config.instantiate(cfg.model) | ||
model.load_state_dict(_ckpt_dict[training.MODEL_KEY]) | ||
self.logger.info(f"Model was initialized with precision {self.dtype}.") | ||
|
||
# Quantize the model if specified | ||
if cfg.get("quantization_method") is not None: | ||
from torchao.quantization.quant_api import quantize_ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lazily import torchao API |
||
|
||
quantization_method = config.instantiate(cfg.quantization_method) | ||
compile_model(model) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compiling the model is necessary for quantization to be really worth it |
||
quantize_(model, quantization_method, device=self.device) | ||
|
||
self.model = model | ||
self._logger.info(f"Model was initialized with precision {self._dtype}.") | ||
|
||
# Instantiate transforms | ||
self.model_transform = config.instantiate(cfg.tokenizer) | ||
|
@@ -105,13 +114,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None: | |
for p in itertools.chain(self.model.parameters(), self.model.buffers()) | ||
] | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s" | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" | ||
) | ||
|
||
|
@@ -128,10 +137,10 @@ def generate(self, cfg: DictConfig): | |
total_response_length = seq_len + cfg.max_new_tokens | ||
|
||
# 3. Setup KV cache | ||
with self._device: | ||
with self.device: | ||
self.model.setup_caches( | ||
batch_size=1, | ||
dtype=self._dtype, | ||
dtype=self.dtype, | ||
encoder_max_seq_len=( | ||
self.model_transform.image_seq_len if is_multimodal_input else None | ||
), | ||
|
@@ -143,7 +152,7 @@ def generate(self, cfg: DictConfig): | |
torch.ones( | ||
size=(total_response_length, total_response_length), | ||
dtype=torch.bool, | ||
device=self._device, | ||
device=self.device, | ||
) | ||
) | ||
input_pos = torch.arange(total_response_length) | ||
|
@@ -155,20 +164,20 @@ def generate(self, cfg: DictConfig): | |
[model_inputs], pad_direction="left", pad_max_images=1 | ||
) | ||
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] | ||
prompt = batch.pop("tokens").to(self._device) | ||
prompt = batch.pop("tokens").to(self.device) | ||
else: | ||
prompt = torch.tensor( | ||
model_inputs["tokens"], device=self._device | ||
).unsqueeze(0) | ||
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted this to fix on one line lol |
||
batch["mask"] = causal_mask[None, :seq_len] | ||
batch["input_pos"] = input_pos[None, :seq_len] | ||
utils.batch_to_device(batch, self._device) | ||
utils.batch_to_device(batch, self.device) | ||
|
||
# 6. Prefill step | ||
generated_tokens = [] | ||
t0 = time.perf_counter() | ||
logits = self.model(prompt, **batch)[:, -1] | ||
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) | ||
t1 = time.perf_counter() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that we might have a warmup run, we log this differently so the user can see how much help quantization / compilation helps. |
||
self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec") | ||
generated_tokens.append(token.item()) | ||
|
||
if is_multimodal_input: | ||
|
@@ -192,15 +201,15 @@ def generate(self, cfg: DictConfig): | |
generated_tokens.append(token.item()) | ||
seq_len += 1 | ||
|
||
t = time.perf_counter() - t0 | ||
t2 = time.perf_counter() - t1 | ||
|
||
# 8. Translate tokens back to text | ||
decoded = self.model_transform.decode(generated_tokens) | ||
self._logger.info(f"\n\n{decoded}\n") | ||
self.logger.info(f"\n{decoded}\n") | ||
|
||
# 9. Log metrics | ||
tokens_per_second = len(generated_tokens) / t | ||
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second) | ||
tokens_per_second = len(generated_tokens) / t2 | ||
self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second) | ||
|
||
|
||
@config.parse | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leave this commented out until the user wants to do something with it.