Allow us to set the dtype (for 'cpu' device) #23

Open
jukofyork opened this issue Jun 5, 2024 · 0 comments

jukofyork commented Jun 5, 2024

It would be helpful if we could pass the data type instead of defaulting to dtype=torch.bfloat16.

I managed to get this running on the 'cpu' device with OpenMP by hacking the setting to dtype=torch.float32. With the default dtype=torch.bfloat16 it just sat all night using a single core and never progressed further (this is an older Xeon without native bf16 support, so it was likely upcasting to fp32 on the fly).
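For example, something along these lines would cover it (a rough sketch rather than this repo's actual CLI; the `--dtype` flag and the name-to-dtype mapping are just illustrative, but `HookedTransformer.from_pretrained` does accept `device` and `dtype` arguments):

```python
# Rough sketch only: expose the torch dtype as an option instead of
# hard-coding torch.bfloat16. The "--dtype" flag and the DTYPES mapping
# are illustrative, not existing options in this repo.
import argparse

import torch
from transformer_lens import HookedTransformer

DTYPES = {
    "float32": torch.float32,    # safest on older CPUs without native bf16
    "bfloat16": torch.bfloat16,  # the current hard-coded default
    "float16": torch.float16,
}

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--device", default="cpu")
parser.add_argument("--dtype", choices=list(DTYPES), default="bfloat16")
args = parser.parse_args()

model = HookedTransformer.from_pretrained(
    args.model,
    device=args.device,
    dtype=DTYPES[args.dtype],  # e.g. float32 for CPUs without bf16 support
)
```

On the Xeon above I would then just run with `--device cpu --dtype float32`.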


I also tried preloading the model in 4-bit (with the 'cuda' device), which should work for Llama models judging by this check in TransformerLens (a sketch of the preloading I tried follows after the excerpt):

```python
if hf_model is not None:
    hf_cfg = hf_model.config.to_dict()
    qc = hf_cfg.get("quantization_config", {})
    load_in_4bit = qc.get("load_in_4bit", False)
    load_in_8bit = qc.get("load_in_8bit", False)
    quant_method = qc.get("quant_method", "")
    assert not load_in_8bit, "8-bit quantization is not supported"
    assert not (
        load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
    ), "Quantization is only supported for torch versions >= 2.1.1"
    assert not (
        load_in_4bit and ("llama" not in model_name.lower())
    ), "Quantization is only supported for Llama models"
    if load_in_4bit:
        assert (
            qc.get("quant_method", "") == "bitsandbytes"
        ), "Only bitsandbytes quantization is supported"
else:
    hf_cfg = {}
```

from: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/HookedTransformer.py
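For context, the 4-bit preloading I attempted looked roughly like this (a sketch of the general approach rather than the exact code I ran; the model name is just an example and the bitsandbytes settings are the standard transformers route):

```python
# Sketch of the 4-bit preloading attempt: quantize with bitsandbytes via
# transformers, then hand the already-quantized model to TransformerLens
# through the `hf_model` argument. The model name is only an example.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformer_lens import HookedTransformer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # example Llama model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="cuda",
)

model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=hf_model,  # pass the pre-quantized model through
    device="cuda",
    dtype=torch.bfloat16,
)
```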

But I got the same shape mismatch exception as mentioned in this thread:

TransformerLensOrg/TransformerLens#569

It might be worth adding the ability to use 4-bit if they fix this bug.
