Bug of quantizing part of the Qwen model #1230

Open
Kha-Zix-1 opened this issue Mar 6, 2025 · 0 comments
Labels
bug Something isn't working

Comments


Kha-Zix-1 commented Mar 6, 2025

I want to quantize layers 12~23 of Qwen1.5-0.5B. I am using llmcompressor 0.3.1. Here is my script:

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.transformers import oneshot

import time
import os
import logging
import math
from datasets import Dataset, DatasetDict
import json
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import argparse

# Load a JSON file

def load_json_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data

def convert_to_message_format(raw_data):
    messages = []
    # Append any multi-turn history first
    instruction = raw_data.get("instruction", "")
    input_data = raw_data.get("input", "")  # keep the input field even when it is empty
    output = raw_data.get("output", "")
    history = raw_data.get("history", [])
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    # Combine instruction and input into the user turn
    messages.append({"role": "user", "content": instruction + " " + input_data})
    # The expected output becomes the assistant turn
    messages.append({"role": "assistant", "content": output})
    return {"messages": messages}

# Select model and load it.
MODEL_ID = "/NVME1/elecLLM/models/qwen/qwen1.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
data_source = "/NVME1/elecLLM/models/Quantized_models/stage2/"

raw_data_list_1 = load_json_file(data_source+'4000_new_style.json')
raw_data_list_2 = load_json_file(data_source+'general.json')
raw_data_list_3 = load_json_file(data_source+'rag_stage1.json')
raw_data_list_4 = load_json_file(data_source+'elec_rag_gpt4o_reanswer_xtuner_v2.json')
raw_data_list_5 = load_json_file(data_source+'online_qa_rag_4o_xtuner_v2.json')

raw_data_list = []
raw_data_list.extend(raw_data_list_1)
raw_data_list.extend(raw_data_list_2)
raw_data_list.extend(raw_data_list_3)

# Convert the raw records to message format

ds = [convert_to_message_format(data) for data in raw_data_list]
ds.extend(raw_data_list_4)
ds.extend(raw_data_list_5)

# Convert the list of dicts to a Pandas DataFrame
ds  = pd.DataFrame(ds)

# Convert the DataFrame to a Hugging Face Dataset
ds  = Dataset.from_pandas(ds)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }

ds = ds.map(preprocess)

# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure algorithms. In this case, we:
#   * apply SmoothQuant to make the activations easier to quantize
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)

recipe = """
quant_stage:
    quant_modifiers:
        SmoothQuantModifier:
            ignore: ["lm_head", "re:model.layers.(0?[0-9]|1[01]).*$"]
            smoothing_strength: 0.8
        GPTQModifier:
            ignore: ["lm_head", "re:model.layers.(0?[0-9]|1[01]).*$"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: int
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    input_activations:
                        num_bits: 8
                        type: int
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["Linear"]
"""

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)

# Apply algorithms.

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk compressed.
SAVE_DIR = "/NVME1/elecLLM/models/Quantized_models/Qwen-72B-Instruct-W8A8-INT8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

I use ignore: ["lm_head", "re:model.layers.(0?[0-9]|1[01]).*$"] to avoid quantizing layers 0~11. However, after quantization, the model config is:

{
  "_name_or_path": "/NVME1/elecLLM/models/qwen/qwen1.5-0.5B",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 2816,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_key_value_heads": 16,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.49.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

It seems that quantization was not applied to the model at all. How can I solve this? How can I quantize only part of a model? Looking forward to your help!
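
For reference, here is a minimal sketch I can use to check which layer names the ignore pattern actually covers, assuming the part after the "re:" prefix is evaluated as an ordinary Python regular expression (the module name format and the anchored_pattern variant below are my own assumptions, not part of llmcompressor):

import re

# Hypothetical check (not llmcompressor code): which Qwen2 layer names the
# ignore pattern matches, assuming the text after "re:" is applied with re.match.
ignore_pattern = r"model.layers.(0?[0-9]|1[01]).*$"
# A tighter, hypothetical alternative that anchors the layer index on the
# following dot, so only layers 0~11 can match.
anchored_pattern = r"model\.layers\.([0-9]|1[01])\."

for i in range(24):
    name = f"model.layers.{i}.self_attn.q_proj"
    original = bool(re.match(ignore_pattern, name))
    anchored = bool(re.match(anchored_pattern, name))
    print(f"layer {i:2d}: original ignores={original}, anchored ignores={anchored}")

If the ignore patterns really are matched this way, the original pattern would cover every layer (for example, layer 12 matches because [0-9] consumes the leading 1 and .* consumes the rest), which could explain why nothing was quantized, while the anchored variant matches only layers 0~11.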

@Kha-Zix-1 Kha-Zix-1 added the bug Something isn't working label Mar 6, 2025
@Kha-Zix-1 Kha-Zix-1 changed the title Bug of quantizing part of the models Bug of quantizing part of the Qwen model Mar 6, 2025