diff --git a/recipes/configs/llama2/generation_v2.yaml b/recipes/configs/llama2/generation_v2.yaml
index 7ce4e2c43d..07ee7830a9 100644
--- a/recipes/configs/llama2/generation_v2.yaml
+++ b/recipes/configs/llama2/generation_v2.yaml
@@ -9,6 +9,10 @@
 # Model arguments
 model:
   _component_: torchtune.models.llama2.llama2_7b
+# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
+# quantization_method:
+#   _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
+#   use_hqq: False # Turn on for more accurate results
 
 # Transform arguments
 tokenizer:
@@ -27,16 +31,16 @@ checkpointer:
   output_dir: ./
   model_type: LLAMA2
 
-# Device
-device: cuda
-dtype: bf16
-seed: 1234
-log_level: INFO
-
 # Generation arguments
 prompt:
   system: You are a helpful and creative AI assistant.
   user: What is the capital of France?
-max_new_tokens: 200
+max_new_tokens: 500
 temperature: 0.6 # 0.8 and 0.6 are popular values to try
 top_k: 300
+
+# Device
+device: cuda
+dtype: bf16
+seed: 1234
+log_level: INFO
diff --git a/recipes/dev/generate_v2.py b/recipes/dev/generate_v2.py
index e63ea2dcb0..5688f86b01 100644
--- a/recipes/dev/generate_v2.py
+++ b/recipes/dev/generate_v2.py
@@ -17,6 +17,7 @@
 from torchtune.generation import sample
 
 from torchtune.modules.transforms import Transform
+from torchtune.training import compile_model
 
 
 class SingleTurnYAMLToMessages(Transform):
@@ -65,29 +66,37 @@ class InferenceRecipe:
 
     This *does not* currently support the following features:
         - torch.compile
-        - quantization through torchao
         - multi-GPU generation
         - batch generation
     """

     def __init__(self, cfg: DictConfig) -> None:
-        self._device = utils.get_device(device=cfg.device)
-        self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
-        self._logger = utils.get_logger(cfg.log_level)
+        self.device = utils.get_device(device=cfg.device)
+        self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device)
+        self.logger = utils.get_logger(cfg.log_level)
         training.set_seed(seed=cfg.seed)
 
     def setup(self, cfg: DictConfig) -> None:
         """Setup the model and transforms."""
-        # Load checkpointer and state_dict
+        # Load checkpointer
         _checkpointer = config.instantiate(cfg.checkpointer)
         _ckpt_dict = _checkpointer.load_checkpoint()
 
         # Instantiate model
-        with training.set_default_dtype(self._dtype), self._device:
+        with training.set_default_dtype(self.dtype), self.device:
             model = config.instantiate(cfg.model)
         model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
+        self.logger.info(f"Model was initialized with precision {self.dtype}.")
+
+        # Quantize the model if specified
+        if cfg.get("quantization_method") is not None:
+            from torchao.quantization.quant_api import quantize_
+
+            quantization_method = config.instantiate(cfg.quantization_method)
+            compile_model(model)
+            quantize_(model, quantization_method, device=self.device)
+
         self.model = model
-        self._logger.info(f"Model was initialized with precision {self._dtype}.")
 
         # Instantiate transforms
         self.model_transform = config.instantiate(cfg.tokenizer)
@@ -105,13 +114,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
                 for p in itertools.chain(self.model.parameters(), self.model.buffers())
             ]
         )
-        self._logger.info(
+        self.logger.info(
             f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
         )
-        self._logger.info(
+        self.logger.info(
             f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
         )
-        self._logger.info(
+        self.logger.info(
             f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
         )
@@ -128,10 +137,10 @@ def generate(self, cfg: DictConfig):
         total_response_length = seq_len + cfg.max_new_tokens
 
         # 3. Setup KV cache
-        with self._device:
+        with self.device:
             self.model.setup_caches(
                 batch_size=1,
-                dtype=self._dtype,
+                dtype=self.dtype,
                 encoder_max_seq_len=(
                     self.model_transform.image_seq_len if is_multimodal_input else None
                 ),
@@ -143,7 +152,7 @@ def generate(self, cfg: DictConfig):
             torch.ones(
                 size=(total_response_length, total_response_length),
                 dtype=torch.bool,
-                device=self._device,
+                device=self.device,
             )
         )
         input_pos = torch.arange(total_response_length)
@@ -155,20 +164,20 @@ def generate(self, cfg: DictConfig):
                 [model_inputs], pad_direction="left", pad_max_images=1
             )
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            prompt = batch.pop("tokens").to(self._device)
+            prompt = batch.pop("tokens").to(self.device)
         else:
-            prompt = torch.tensor(
-                model_inputs["tokens"], device=self._device
-            ).unsqueeze(0)
+            prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
         batch["mask"] = causal_mask[None, :seq_len]
         batch["input_pos"] = input_pos[None, :seq_len]
-        utils.batch_to_device(batch, self._device)
+        utils.batch_to_device(batch, self.device)
 
         # 6. Prefill step
         generated_tokens = []
         t0 = time.perf_counter()
         logits = self.model(prompt, **batch)[:, -1]
         token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
+        t1 = time.perf_counter()
+        self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec")
         generated_tokens.append(token.item())
 
         if is_multimodal_input:
@@ -192,15 +201,15 @@ def generate(self, cfg: DictConfig):
             generated_tokens.append(token.item())
             seq_len += 1
 
-        t = time.perf_counter() - t0
+        t2 = time.perf_counter() - t1
 
         # 8. Translate tokens back to text
         decoded = self.model_transform.decode(generated_tokens)
-        self._logger.info(f"\n\n{decoded}\n")
+        self.logger.info(f"\n{decoded}\n")
 
         # 9. Log metrics
-        tokens_per_second = len(generated_tokens) / t
-        self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)
+        tokens_per_second = len(generated_tokens) / t2
+        self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second)
 
 
 @config.parse
diff --git a/tests/conftest.py b/tests/conftest.py
index 7618d393e0..bae7ab0ff1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import logging
 import os
 import uuid
 from pathlib import Path
@@ -18,6 +19,13 @@
 CACHE_ARTIFACTS_SCRIPT_PATH = root + "/tests/cache_artifacts.sh"
 
 
+def pytest_sessionfinish():
+    """
+    Register a hook to suppress logging errors after the session finishes.
+    """
+    logging.raiseExceptions = False
+
+
 def pytest_configure(config):
     """
     This hook runs before each pytest invocation. Its purpose is to handle optional fetching
diff --git a/tests/recipes/dev/test_generate_v2.py b/tests/recipes/dev/test_generate_v2.py
index be3f995f58..208bfd281c 100644
--- a/tests/recipes/dev/test_generate_v2.py
+++ b/tests/recipes/dev/test_generate_v2.py
@@ -10,9 +10,16 @@
 
 import pytest
 
+import torch
+
 from tests.common import TUNE_PATH
 from tests.recipes.utils import MODEL_TEST_CONFIGS, write_hf_ckpt_config
-from tests.test_utils import CKPT_MODEL_PATHS, mps_ignored_test, TOKENIZER_PATHS
+from tests.test_utils import (
+    CKPT_MODEL_PATHS,
+    gpu_test,
+    mps_ignored_test,
+    TOKENIZER_PATHS,
+)
 
 
 class TestGenerateV2:
@@ -62,6 +69,53 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
         logs = caplog.text
         assert expected_output in logs
 
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_llama2_generate_with_quantization(self, caplog, monkeypatch, tmpdir):
+        ckpt = "llama2_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS["llama2"])
+        ckpt_dir = ckpt_path.parent
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+
+        cmd = f"""
+        tune run dev/generate_v2 \
+        --config llama2/generation_v2 \
+        output_dir={tmpdir} \
+        checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+        checkpointer.checkpoint_dir='{ckpt_dir}' \
+        checkpointer.checkpoint_files=[{ckpt_path}]\
+        checkpointer.output_dir={tmpdir} \
+        checkpointer.model_type=LLAMA2 \
+        tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+        device=cuda \
+        dtype=bf16 \
+        max_new_tokens=10 \
+        seed=123 \
+        quantization_method._component_=torchao.quantization.quant_api.int4_weight_only \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama2"]
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # this is gibberish b/c the model is random weights, but it's
+        # the expected value for what we currently have in V2
+        # this test should catch any changes to the generate recipe that affect output
+        expected_output = (
+            "Halfotherтература retir pushingroad Chem CURLorientationocation Stadium"
+        )
+
+        torch._dynamo.reset()
+
+        logs = caplog.text
+        assert expected_output in logs
+
     @pytest.mark.integration_test
     def test_llama2_fail_on_bad_input(self, capsys, monkeypatch, tmpdir):
         """Should fail when user passes in a bad input: