W4A8 model larger than W4A16 #1215

Open
chmeyers opened this issue Feb 28, 2025 · 3 comments

Describe the bug
I ran the example w4a16 script on meta-llama/Llama-3.1-8B-Instruct with both the W4A16 and W4A8 schemes, and the W4A8 model came out much larger: 5,700,595,200 bytes for W4A16 vs 9,190,252,544 bytes for W4A8 (5.7 GB vs 9.2 GB; values taken from model.safetensors.index.json, but they match the sizes on disk).

The W4A16 model seems to be the correct size, but the W4A8 model is of similar size to a W8A8 model. Maybe the weight tensors are being saved using 8 bits?
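A rough back-of-the-envelope check (my own arithmetic from the published model dimensions, not output from the tool) seems consistent with that guess:

# Rough size estimate for Llama-3.1-8B-Instruct; my own arithmetic, not produced by llm-compressor.
hidden, inter, layers, kv_dim, vocab = 4096, 14336, 32, 8 * 128, 128256
linear_params = layers * (
    2 * hidden * hidden     # q_proj, o_proj
    + 2 * hidden * kv_dim   # k_proj, v_proj
    + 3 * hidden * inter    # gate_proj, up_proj, down_proj
)                                            # ~6.98e9 quantized linear weights
embed_bytes = 2 * vocab * hidden * 2         # embed_tokens + lm_head kept in bf16, ~2.1 GB
print(f"4-bit packed: {(linear_params // 2 + embed_bytes) / 1e9:.1f} GB")   # ~5.6 GB
print(f"8-bit stored: {(linear_params + embed_bytes) / 1e9:.1f} GB")        # ~9.1 GB

Adding the group-wise scales puts those estimates right around the observed 5.7 GB and 9.2 GB, which is why I suspect the 4-bit weights are landing in 8-bit containers on disk.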

Expected behavior
W4A8 should be smaller, right?

Environment
Include all relevant environment information:

  1. OS: Amazon Linux 2023.6.20250203

  2. Python version: 3.10

  3. llmcompressor==0.4.1

  4. ML framework version(s): torch==2.5.1

  5. Other Python package versions:
    compressed-tensors==0.9.0
    accelerate==1.1.1
    onnx==1.17.0
    optimum==1.23.3
    transformers==4.47.0
    vllm==0.7.0
    ray==2.40.0
    numpy==1.26.4

  6. Other relevant environment information:
    Ran on a Ray node on an AWS g6e.24xlarge (4x L40 GPUs, though only one was used for this model).

To Reproduce
Exact steps to reproduce the behavior:
I used this example: https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py
Ran it twice: once as-is (W4A16) and once with the scheme changed to W4A8 (see the sketch below).
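For reference, the only difference between the two runs was the scheme string passed to GPTQModifier. Roughly (a sketch from memory of the example script, not a verbatim copy; model, ds, and the constants come from the linked example):

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Identical to the linked example except for the scheme string, which was
# "W4A16" on the first run and "W4A8" on the second.
recipe = GPTQModifier(targets="Linear", scheme="W4A8", ignore=["lm_head"])

oneshot(
    model=model,                               # Llama-3.1-8B-Instruct loaded as in the example
    dataset=ds,                                # calibration dataset prepared as in the example
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)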

Errors
N/A

Additional context
N/A

chmeyers added the bug label Feb 28, 2025
dsikka self-assigned this Feb 28, 2025
dsikka (Collaborator) commented Mar 3, 2025

Hi @chmeyers - can you share the config produced as well as the recipe that you applied?

chmeyers (Author) commented Mar 3, 2025

Recipes were:
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
recipe = GPTQModifier(targets="Linear", scheme="W4A8", ignore=["lm_head"])

config.json of the W4A8 model was:
{
  "_name_or_path": "/tmp/tmp8k15e6xr/meta-llama/Llama-3.1-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": true,
          "group_size": null,
          "num_bits": 8,
          "observer": null,
          "observer_kwargs": {},
          "strategy": "token",
          "symmetric": true,
          "type": "int"
        },
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": 128,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "group",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "int-quantized",
    "global_compression_ratio": 1.8917232374233346,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed"
  },
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.0",
  "use_cache": true,
  "vocab_size": 128256
}

dsikka (Collaborator) commented Mar 5, 2025

Hi @chmeyers - the variation that you're seeing is because of the compressor that is applied when saving the quantized model to disk. When doing weight-only quantization (W4A16/W8A16), we select the packed_quantized compressor. When adding in activation quantization, we select the int-quantized or float-quantized compressor.

You can see further details on how the compressor is selected by referring to the docstring and functions listed here.
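If it helps anyone hitting the same question, a quick way to see which compressor a saved checkpoint ended up with is to read the format field from its quantization_config (a stdlib-only sketch, not a compressed-tensors API; the directory name is made up):

import json
from pathlib import Path

def saved_compression_format(model_dir: str) -> str:
    """Return the compressed-tensors format recorded in a saved checkpoint's config.json."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    # Weight-only schemes (W4A16/W8A16) get the packed compressor, while schemes with
    # activation quantization (e.g. W4A8) get the int-/float-quantized compressors,
    # which keep each 4-bit weight in an 8-bit container on disk.
    return config["quantization_config"]["format"]

print(saved_compression_format("Llama-3.1-8B-Instruct-W4A8"))  # prints "int-quantized" for the model above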

dsikka added the question and compressed-tensors labels and removed the bug label Mar 5, 2025