From 378d8b60771cb952685bcf7ed5c2a012d45a1007 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=AB=A0=E7=84=95=E9=94=AD?=
Date: Thu, 4 Jan 2024 15:04:13 +0800
Subject: [PATCH 1/2] fix the tokenize process of chatglm3

---
 fastchat/conversation.py        |  4 ++--
 fastchat/model/model_chatglm.py | 31 ++++++++++++++++++++++++++++++-
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index ac983c18e..d80de09f3 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -172,7 +172,7 @@ def get_prompt(self) -> str:
                 ret += system_prompt
             for role, message in self.messages:
                 if message:
-                    ret += role + "\n" + " " + message
+                    ret += role + "\n" + message
                 else:
                     ret += role
             return ret
@@ -487,7 +487,7 @@ def get_conv_template(name: str) -> Conversation:
 register_conv_template(
     Conversation(
         name="chatglm3",
-        system_template="<|system|>\n {system_message}",
+        system_template="<|system|>\n{system_message}",
         roles=("<|user|>", "<|assistant|>"),
         sep_style=SeparatorStyle.CHATGLM3,
         stop_token_ids=[
diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py
index 5d4db62bc..f967f5d76 100644
--- a/fastchat/model/model_chatglm.py
+++ b/fastchat/model/model_chatglm.py
@@ -36,6 +36,27 @@ def process_response(response):
         response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
     return response
 
+def recover_message_list(prompt):
+    role_token_pattern = "|".join([re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]])
+    role = None
+    last_end_idx = -1
+    message_list = []
+    for match in re.finditer(role_token_pattern, prompt):
+        if role:
+            messge = {}
+            if role == "<|system|>":
+                messge["role"] = "system"
+            elif role == "<|user|>":
+                messge["role"] = "user"
+            else:
+                messge["role"] = "assistant"
+            messge["content"] = prompt[last_end_idx + 1: match.start()]
+            message_list.append(messge)
+
+        role = prompt[match.start(): match.end()]
+        last_end_idx = match.end()
+
+    return message_list
 
 @torch.inference_mode()
 def generate_stream_chatglm(
@@ -54,7 +75,15 @@ def generate_stream_chatglm(
     max_new_tokens = int(params.get("max_new_tokens", 256))
     echo = params.get("echo", True)
 
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    model_type = str(type(model)).lower()
+    if "peft" in model_type:
+        model_type = str(type(model.base_model.model)).lower()
+
+    if "chatglm3" in model_type:
+        message_list = recover_message_list(prompt)
+        inputs = tokenizer.build_chat_input(query=message_list[-1]["content"], history=message_list[:-1], role='user').to(model.device)
+    else:
+        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
     input_echo_len = len(inputs["input_ids"][0])
 
     gen_kwargs = {

From e53ba800e063606dd4c2f6c66e15bca08a5dc6c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=AB=A0=E7=84=95=E9=94=AD?=
Date: Fri, 5 Jan 2024 12:22:23 +0800
Subject: [PATCH 2/2] format

---
 fastchat/model/model_chatglm.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py
index f967f5d76..2cbac8bc5 100644
--- a/fastchat/model/model_chatglm.py
+++ b/fastchat/model/model_chatglm.py
@@ -36,8 +36,11 @@ def process_response(response):
         response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
     return response
 
+
 def recover_message_list(prompt):
-    role_token_pattern = "|".join([re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]])
+    role_token_pattern = "|".join(
+        [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]]
+    )
     role = None
     last_end_idx = -1
     message_list = []
@@ -50,14 +53,15 @@ def recover_message_list(prompt):
                 messge["role"] = "user"
             else:
                 messge["role"] = "assistant"
-            messge["content"] = prompt[last_end_idx + 1: match.start()]
+            messge["content"] = prompt[last_end_idx + 1 : match.start()]
             message_list.append(messge)
 
-        role = prompt[match.start(): match.end()]
+        role = prompt[match.start() : match.end()]
         last_end_idx = match.end()
 
     return message_list
 
+
 @torch.inference_mode()
 def generate_stream_chatglm(
     model,
@@ -81,7 +85,9 @@ def generate_stream_chatglm(
 
     if "chatglm3" in model_type:
         message_list = recover_message_list(prompt)
-        inputs = tokenizer.build_chat_input(query=message_list[-1]["content"], history=message_list[:-1], role='user').to(model.device)
+        inputs = tokenizer.build_chat_input(
+            query=message_list[-1]["content"], history=message_list[:-1], role="user"
+        ).to(model.device)
     else:
         inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
     input_echo_len = len(inputs["input_ids"][0])