[WIP] PTQ for generate_v2 #1866

Draft · wants to merge 4 commits into main
Changes from 3 commits
18 changes: 11 additions & 7 deletions recipes/configs/llama2/generation_v2.yaml
@@ -9,6 +9,10 @@
# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM usage
Contributor Author:
Leave this commented out until the user wants to do something with it.

# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
# use_hqq: False # Turn on for more accurate results

# Transform arguments
tokenizer:
@@ -27,16 +31,16 @@ checkpointer:
output_dir: ./
model_type: LLAMA2

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
max_new_tokens: 200
max_new_tokens: 500
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO
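
For reference, the commented-out quantization block above can also be enabled at launch time without editing the file, by overriding the component from the command line. A minimal sketch, mirroring the override used in the new integration test further down in this PR:

    tune run dev/generate_v2 --config llama2/generation_v2 \
        quantization_method._component_=torchao.quantization.quant_api.int4_weight_only

Nested options such as quantization_method.use_hqq can be overridden the same way.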
53 changes: 31 additions & 22 deletions recipes/dev/generate_v2.py
@@ -17,6 +17,7 @@
from torchtune.generation import sample

from torchtune.modules.transforms import Transform
from torchtune.training import compile_model


class SingleTurnYAMLToMessages(Transform):
@@ -65,29 +66,37 @@ class InferenceRecipe:

This *does not* currently support the following features:
- torch.compile
- quantization through torchao
- multi-GPU generation
- batch generation
"""

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
self.device = utils.get_device(device=cfg.device)
Contributor Author:
It's a public recipe, no need to be a "private" variable.

cc @pbontrager

self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device)
self.logger = utils.get_logger(cfg.log_level)
training.set_seed(seed=cfg.seed)

def setup(self, cfg: DictConfig) -> None:
"""Setup the model and transforms."""
# Load checkpointer and state_dict
# Load checkpointer
_checkpointer = config.instantiate(cfg.checkpointer)
_ckpt_dict = _checkpointer.load_checkpoint()

# Instantiate model
with training.set_default_dtype(self._dtype), self._device:
with training.set_default_dtype(self.dtype), self.device:
model = config.instantiate(cfg.model)
model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
self.logger.info(f"Model was initialized with precision {self.dtype}.")

# Quantize the model if specified
if cfg.get("quantization_method") is not None:
from torchao.quantization.quant_api import quantize_
Contributor Author:
Lazily import torchao API


quantization_method = config.instantiate(cfg.quantization_method)
compile_model(model)
Contributor Author:
Compiling the model is necessary for quantization to be really worth it

quantize_(model, quantization_method, device=self.device)

self.model = model
self._logger.info(f"Model was initialized with precision {self._dtype}.")

# Instantiate transforms
self.model_transform = config.instantiate(cfg.tokenizer)
@@ -105,13 +114,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
for p in itertools.chain(self.model.parameters(), self.model.buffers())
]
)
self._logger.info(
self.logger.info(
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
self.logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
self.logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)

@@ -128,10 +137,10 @@ def generate(self, cfg: DictConfig):
total_response_length = seq_len + cfg.max_new_tokens

# 3. Setup KV cache
with self._device:
with self.device:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
dtype=self.dtype,
encoder_max_seq_len=(
self.model_transform.image_seq_len if is_multimodal_input else None
),
@@ -143,7 +152,7 @@
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
device=self._device,
device=self.device,
)
)
input_pos = torch.arange(total_response_length)
@@ -155,20 +164,20 @@
[model_inputs], pad_direction="left", pad_max_images=1
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
prompt = batch.pop("tokens").to(self._device)
prompt = batch.pop("tokens").to(self.device)
else:
prompt = torch.tensor(
model_inputs["tokens"], device=self._device
).unsqueeze(0)
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
Contributor Author:
I wanted this to fit on one line lol

batch["mask"] = causal_mask[None, :seq_len]
batch["input_pos"] = input_pos[None, :seq_len]
utils.batch_to_device(batch, self._device)
utils.batch_to_device(batch, self.device)

# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
t1 = time.perf_counter()
Contributor Author:
Now that we might have a warmup run, we log this separately so the user can see how much quantization / compilation helps.

self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec")
generated_tokens.append(token.item())

if is_multimodal_input:
@@ -192,15 +201,15 @@
generated_tokens.append(token.item())
seq_len += 1

t = time.perf_counter() - t0
t2 = time.perf_counter() - t1

# 8. Translate tokens back to text
decoded = self.model_transform.decode(generated_tokens)
self._logger.info(f"\n\n{decoded}\n")
self.logger.info(f"\n{decoded}\n")

# 9. Log metrics
tokens_per_second = len(generated_tokens) / t
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)
tokens_per_second = len(generated_tokens) / t2
self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second)


@config.parse
52 changes: 51 additions & 1 deletion tests/recipes/dev/test_generate_v2.py
@@ -12,7 +12,12 @@

from tests.common import TUNE_PATH
from tests.recipes.utils import MODEL_TEST_CONFIGS, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS, mps_ignored_test, TOKENIZER_PATHS
from tests.test_utils import (
CKPT_MODEL_PATHS,
gpu_test,
mps_ignored_test,
TOKENIZER_PATHS,
)


class TestGenerateV2:
@@ -62,6 +67,51 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_llama2_generate_with_quantization(self, caplog, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS["llama2"])
ckpt_dir = ckpt_path.parent

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)

cmd = f"""
tune run dev/generate_v2 \
--config llama2/generation_v2 \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
device=cuda \
dtype=bf16 \
max_new_tokens=10 \
seed=123 \
quantization_method._component_=torchao.quantization.quant_api.int4_weight_only \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2"]
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# this is gibberish b/c the model has random weights, but it's
# the expected value for what we currently have in V2
# this test should catch any changes to the generate recipe that affect output
expected_output = (
"Halfotherтература retir pushingroad Chem CURLorientationocation Stadium"
)

logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
def test_llama2_fail_on_bad_input(self, capsys, monkeypatch, tmpdir):
"""Should fail when user passes in a bad input:
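
A hedged note on running the new quantization test locally: it is gated by @pytest.mark.integration_test and @gpu_test(gpu_count=1), so it needs a CUDA device, the checkpoint and tokenizer artifacts referenced by CKPT_MODEL_PATHS and TOKENIZER_PATHS, and the integration marker selected. Assuming the repository's standard pytest setup, something like:

    pytest tests/recipes/dev/test_generate_v2.py -m integration_test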