support deepspeed tp load int4 checkpoint (#3328)

* tp int4 checkpoint

* small changes.

* update

* update lm_head

* unnecessary change

* some change on tp

* tp update

* simplify the code.

* modify run_accuracy_with_deepspeed.

* modify according to comment.

* Support low precision checkpoint with TP in llm.optimize

* Revert some changes in llm.optimize

* fix bug for gpt-j.

* remove unnecessary change.

* remove unnecessary change.

* fix bug.

* support mixtral.

* support mixtral.

* flake8 format.

* fix bug.

---------

Co-authored-by: Tao, Ran <[email protected]>
Co-authored-by: Xia, Weiwen <[email protected]>
3 people authored Dec 3, 2024
1 parent 4679764 commit ae09c58
Showing 7 changed files with 494 additions and 64 deletions.
183 changes: 183 additions & 0 deletions examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py
@@ -183,6 +183,14 @@ def decorator(func):
help="Quantize weight symmetrically for weight only quantization. It usually brings better latency at"
" the cost of accuracy. It has not effect if you are loading low-precision checkpoints.",
)
parser.add_argument(
"--low-precision-checkpoint",
default="",
type=str,
help="Low precision checkpoint file generated by algorithms, such as GPTQ. It contains"
" INT4 weights, scales, zero points, etc. For better accuracy of weight only"
" quantization with INT4 weight.",
)

args = parser.parse_args()
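For context on what the new flag consumes: a low-precision checkpoint is a flat state dict of packed INT4 tensors plus their scales and zero points. Below is a minimal sketch of writing a toy checkpoint in both accepted formats; the key names, shapes, and sizes are illustrative only (Llama-style naming, AWQ packing as noted in the diff below), not something the script mandates.

import torch
from safetensors.torch import save_file

# Hypothetical packed-INT4 tensors for one linear layer; real checkpoints
# are produced by GPTQ/AWQ tools, not written by hand.
K, N, G = 4096, 4096, 128  # in-features, out-features, quantization group size
ckpt = {
    "model.layers.0.self_attn.q_proj.qweight": torch.zeros(K, N // 8, dtype=torch.int32),
    "model.layers.0.self_attn.q_proj.scales": torch.zeros(K // G, N, dtype=torch.float16),
    "model.layers.0.self_attn.q_proj.qzeros": torch.zeros(K // G, N // 8, dtype=torch.int32),
}
torch.save(ckpt, "ckpt.pt")          # consumed via torch.load(..., weights_only=True)
save_file(ckpt, "ckpt.safetensors")  # consumed via safetensors.torch.load_file

Either file can be passed directly to --low-precision-checkpoint, or a directory containing such shards can be passed instead.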

@@ -394,7 +402,84 @@ def write_checkpoints_json():
        )

        self.model = self.model.module
        import pathlib

        if args.low_precision_checkpoint != "":
            pathname = args.low_precision_checkpoint
            assert os.path.exists(
                pathname
            ), f"Checkpoint file does not exist: {pathname}"
            if os.path.isfile(pathname):
                low_precision_checkpoint = None
                if pathname.endswith(".pt") or pathname.endswith(".pth"):
                    low_precision_checkpoint = torch.load(pathname, weights_only=True)
                elif pathname.endswith(".safetensors"):
                    try:
                        import safetensors.torch
                    except ImportError:
                        print(
                            "Please install safetensors package to load safetensors checkpoint."
                        )
                        exit(1)
                    low_precision_checkpoint = safetensors.torch.load_file(pathname)
                assert (
                    low_precision_checkpoint is not None
                ), f"Invalid checkpoint file: {pathname}. Should be a .pt, .pth or .safetensors file."

                quant_method = {"quant_method": "gptq"}

            elif os.path.isdir(pathname):
                # A directory may contain sharded checkpoint files; merge them into one dict.
                low_precision_checkpoint = {}
                for pattern in ["*.pt", "*.pth"]:
                    files = list(pathlib.Path(pathname).glob(pattern))
                    if files:
                        for f in files:
                            data_f = torch.load(f, weights_only=True)
                            low_precision_checkpoint.update(data_f)
                        break
                if not low_precision_checkpoint:
                    files = list(pathlib.Path(pathname).glob("*.safetensors"))
                    if files:
                        try:
                            import safetensors.torch
                        except ImportError:
                            print(
                                "Please install safetensors package to load safetensors checkpoint."
                            )
                            exit(1)
                        for f in files:
                            data_f = safetensors.torch.load_file(f)
                            low_precision_checkpoint.update(data_f)
                assert (
                    len(low_precision_checkpoint) > 0
                ), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path {pathname}."

                try:
                    with open(pathname + "/config.json") as f:
                        quant_model_config = json.load(f)
                    quant_method = {
                        "quant_method": quant_model_config["quantization_config"][
                            "quant_method"
                        ]
                    }
                except Exception as e:
                    print(
                        "warning: loading HF config.json to get `quant_method` failed, due to ",
                        e,
                    )
                    print("warning: specifying `quant_method` = `gptq` by default.")
                    quant_method = {"quant_method": "gptq"}

            else:
                raise AssertionError(
                    f"Invalid low-precision-checkpoint: {pathname}."
                    " Should be a .pt/.pth/.safetensors file or a directory containing them."
                )

            low_precision_checkpoint = (low_precision_checkpoint, quant_method)
            low_precision_checkpoint_dict = low_precision_checkpoint[0]
        else:
            low_precision_checkpoint = None
        if self._with_ipex:
            ipex_woq_enabled = args.ipex_weight_only_quantization
            if ipex_woq_enabled:
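The branch above normalizes every input into one convention: a (state_dict, {"quant_method": ...}) tuple, which the ipex.llm.optimize call at the end of this diff consumes. A condensed restatement, where merge_shards and read_quant_method are hypothetical stand-ins for the inline directory logic above:

# Condensed restatement of the normalization above (not new behavior).
# merge_shards / read_quant_method are hypothetical names for the inline logic.
if os.path.isfile(pathname):
    state_dict = torch.load(pathname, weights_only=True)  # or safetensors.torch.load_file
    quant_method = {"quant_method": "gptq"}  # single files are assumed GPTQ
else:
    state_dict = merge_shards(pathname)  # union of all .pt/.pth/.safetensors shards
    quant_method = {"quant_method": read_quant_method(pathname)}  # config.json, else "gptq"
low_precision_checkpoint = (state_dict, quant_method)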
@@ -447,13 +532,111 @@ def write_checkpoints_json():
                    group_size=args.group_size,
                    weight_qscheme=weight_qscheme,
                )
            model = self.model
            if low_precision_checkpoint is not None:
                num_heads = model.config.num_attention_heads
                rank = local_rank

                # Layers split along N (output channels, column-parallel).
                layers_split_by_N = [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "gate_proj",
                    "up_proj",
                    "fc_in",
                    "fc1",
                    "query_key_value",
                    "w1",
                    "w3",
                ]
                # Layers split along K (input channels, row-parallel).
                layers_split_by_K = [
                    "o_proj",
                    "down_proj",
                    "fc_out",
                    "fc2",
                    "out_proj",
                    "dense",
                    "dense_4h_to_h",
                    "w2",
                ]
                lm_head_layers = ["lm_head"]  # split by K but not quantized
                # Read from quant_method, which is set for both file and directory
                # checkpoints; quant_model_config exists only in the directory branch.
                quantization_method = quant_method["quant_method"]
                head_range = [0]
                head_per_rank = num_heads // world_size

                for i in range(0, world_size):
                    head_this_rank = head_per_rank
                    if i < num_heads % world_size:
                        head_this_rank += 1
                    head_range.append(head_range[-1] + head_this_rank)
                for key in low_precision_checkpoint[0].keys():
                    q_head_start = head_range[rank]
                    q_head_end = q_head_start + (
                        head_range[rank + 1] - head_range[rank]
                    )
                    if "bias" in key:
                        continue
                    if any(substring in key for substring in layers_split_by_N):
                        data = low_precision_checkpoint_dict[key]
                        if quantization_method == "awq":
                            # awq qweight: [K, N // 8]
                            # awq scales: [K // G, N]
                            # awq qzeros: [K // G, N // 8]
                            dim = data.shape[-1] // head_range[-1]
                            low_precision_checkpoint_dict[key] = data[
                                :, q_head_start * dim : q_head_end * dim
                            ]
                        else:
                            raise AssertionError(
                                f"{quantization_method} is not supported yet."
                            )
                    if any(substring in key for substring in layers_split_by_K):
                        data = low_precision_checkpoint_dict[key]
                        if quantization_method == "awq":
                            # awq qweight: [K, N // 8]
                            # awq scales: [K // G, N]
                            # awq qzeros: [K // G, N // 8]
                            if data.shape[0] % head_range[-1] == 0:
                                dim = data.shape[0] // head_range[-1]
                            else:
                                assert data.shape[0] % world_size == 0
                                dim = data.shape[0] // world_size
                                q_head_start = local_rank
                                q_head_end = local_rank + 1
                            low_precision_checkpoint_dict[key] = data[
                                q_head_start * dim : q_head_end * dim
                            ]
                        else:
                            raise AssertionError(
                                f"{quantization_method} is not supported yet."
                            )
                    if any(substring in key for substring in lm_head_layers):
                        # lm_head: [N, K] (not quantized)
                        # Same for both AWQ and GPTQ
                        data = low_precision_checkpoint_dict[key]
                        if data.shape[-1] % head_range[-1] == 0:
                            dim = data.shape[-1] // head_range[-1]
                        else:
                            dim = data.shape[-1] // world_size
                            q_head_start = local_rank
                            q_head_end = local_rank + 1
                        low_precision_checkpoint_dict[key] = data[
                            :, q_head_start * dim : q_head_end * dim
                        ]
                low_precision_dict = (low_precision_checkpoint_dict, quant_method)
            else:
                low_precision_dict = None

            self.model = ipex.llm.optimize(
                self.model.eval(),
                dtype=infer_dtype,
                quantization_config=qconfig if ipex_woq_enabled else None,
                inplace=True,
                deployment_mode=False,
                cache_weight_for_large_batch=args.cache_weight_for_large_batch,
                low_precision_checkpoint=low_precision_dict,
            )

        self.base_model = self.model
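To make the tensor-parallel sharding above concrete, here is a small self-contained sketch (all sizes are illustrative, not taken from any real model): heads are distributed as evenly as possible via head_range, layers split by N are sliced on the packed last dimension, and layers split by K on the first.

import torch

num_heads, world_size = 10, 4  # deliberately uneven: 10 heads over 4 ranks

# head_range as built in the diff: ranks with index < (num_heads % world_size)
# get one extra head, so 10 heads over 4 ranks -> [0, 3, 6, 8, 10].
head_range = [0]
head_per_rank = num_heads // world_size
for i in range(world_size):
    head_this_rank = head_per_rank + (1 if i < num_heads % world_size else 0)
    head_range.append(head_range[-1] + head_this_rank)
assert head_range == [0, 3, 6, 8, 10]

K, N = 40, 80  # both divisible by num_heads, as the head-based path requires
qweight = torch.arange(K * (N // 8)).reshape(K, N // 8)  # AWQ packing: [K, N // 8]

rank = 1
q_head_start, q_head_end = head_range[rank], head_range[rank + 1]  # heads 3..6

# Column-parallel layers (q_proj, up_proj, ...): slice the packed N dimension.
dim_n = qweight.shape[-1] // head_range[-1]
shard_n = qweight[:, q_head_start * dim_n : q_head_end * dim_n]

# Row-parallel layers (o_proj, down_proj, ...): slice the K dimension.
dim_k = qweight.shape[0] // head_range[-1]
shard_k = qweight[q_head_start * dim_k : q_head_end * dim_k]

print(shard_n.shape, shard_k.shape)  # torch.Size([40, 3]) torch.Size([12, 10])

Rank 1 holds 3 of the 10 head slices, so its column-parallel shard keeps 3/10 of the packed columns and its row-parallel shard keeps 3/10 of the rows.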
