It would be helpful if we could pass the data type instead of defaulting to dtype=torch.bfloat16.
I managed to get this running on the 'cpu' device with OpenMP by hacking the setting to dtype=torch.float32, but with the default dtype=torch.bfloat16 it sat all night using a single core and never progressed (this was on an older Xeon without native bf16 support, so it was likely upcasting to fp32 anyway).
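As a sketch of what I'd like to be able to do, assuming the loader ultimately forwards the value to HookedTransformer.from_pretrained (which already takes a dtype argument); the checkpoint name is just an example:

import torch
from transformer_lens import HookedTransformer

# Hypothetical: request fp32 on CPU instead of the hard-coded bf16 default.
# On an older Xeon without native bf16, fp32 actually makes progress.
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example checkpoint, substitute your own
    device="cpu",
    dtype=torch.float32,
)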
I also tried preloading the model in 4-bit (with the 'cuda' device), which should work for Llama models according to this check in TransformerLens (see the sketch after the excerpt below):
if hf_model is not None:
    hf_cfg = hf_model.config.to_dict()
    qc = hf_cfg.get("quantization_config", {})
    load_in_4bit = qc.get("load_in_4bit", False)
    load_in_8bit = qc.get("load_in_8bit", False)
    quant_method = qc.get("quant_method", "")
    assert not load_in_8bit, "8-bit quantization is not supported"
    assert not (
        load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
    ), "Quantization is only supported for torch versions >= 2.1.1"
    assert not (
        load_in_4bit and ("llama" not in model_name.lower())
    ), "Quantization is only supported for Llama models"
    if load_in_4bit:
        assert (
            qc.get("quant_method", "") == "bitsandbytes"
        ), "Only bitsandbytes quantization is supported"
else:
    hf_cfg = {}
from: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/HookedTransformer.py
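For context, this is roughly how I preloaded the quantized model before handing it to TransformerLens; a minimal sketch using the standard transformers + bitsandbytes API (the model name is illustrative, and dtype/compute-dtype choices are assumptions):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformer_lens import HookedTransformer

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative; any Llama-family checkpoint

# Preload the HF model with bitsandbytes 4-bit quantization so that
# hf_model.config carries a quantization_config that satisfies the asserts above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="cuda",
)

# Hand the quantized model to TransformerLens instead of letting it load its own copy.
model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=hf_model,
    device="cuda",
    dtype=torch.bfloat16,
)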
But I got the same shape mismatch exception as mentioned in this thread:
TransformerLensOrg/TransformerLens#569
Might be worth adding the ability to use 4-bit if they fix this bug.