diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index ac983c18e..d80de09f3 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -172,7 +172,7 @@ def get_prompt(self) -> str:
                 ret += system_prompt
             for role, message in self.messages:
                 if message:
-                    ret += role + "\n" + " " + message
+                    ret += role + "\n" + message
                 else:
                     ret += role
             return ret
@@ -487,7 +487,7 @@ def get_conv_template(name: str) -> Conversation:
 register_conv_template(
     Conversation(
         name="chatglm3",
-        system_template="<|system|>\n {system_message}",
+        system_template="<|system|>\n{system_message}",
         roles=("<|user|>", "<|assistant|>"),
         sep_style=SeparatorStyle.CHATGLM3,
         stop_token_ids=[
diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py
index 5d4db62bc..2cbac8bc5 100644
--- a/fastchat/model/model_chatglm.py
+++ b/fastchat/model/model_chatglm.py
@@ -37,6 +37,31 @@ def process_response(response):
     return response
 
 
+def recover_message_list(prompt):
+    role_token_pattern = "|".join(
+        [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]]
+    )
+    role = None
+    last_end_idx = -1
+    message_list = []
+    for match in re.finditer(role_token_pattern, prompt):
+        if role:
+            message = {}
+            if role == "<|system|>":
+                message["role"] = "system"
+            elif role == "<|user|>":
+                message["role"] = "user"
+            else:
+                message["role"] = "assistant"
+            message["content"] = prompt[last_end_idx + 1 : match.start()]
+            message_list.append(message)
+
+        role = prompt[match.start() : match.end()]
+        last_end_idx = match.end()
+
+    return message_list
+
+
 @torch.inference_mode()
 def generate_stream_chatglm(
     model,
@@ -54,7 +79,17 @@
     max_new_tokens = int(params.get("max_new_tokens", 256))
     echo = params.get("echo", True)
 
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    model_type = str(type(model)).lower()
+    if "peft" in model_type:
+        model_type = str(type(model.base_model.model)).lower()
+
+    if "chatglm3" in model_type:
+        message_list = recover_message_list(prompt)
+        inputs = tokenizer.build_chat_input(
+            query=message_list[-1]["content"], history=message_list[:-1], role="user"
+        ).to(model.device)
+    else:
+        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
     input_echo_len = len(inputs["input_ids"][0])
 
     gen_kwargs = {
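
Note (illustrative, not part of the diff): the new recover_message_list helper walks the flat prompt produced by the chatglm3 conversation template and rebuilds the role/content pairs that tokenizer.build_chat_input expects. A minimal usage sketch follows; the sample prompt string is a hypothetical example of the template's output, not taken from the diff.

    # Sketch: assumes a fastchat checkout with this patch applied, and a prompt
    # laid out as the chatglm3 template builds it ("<|role|>\n" + content).
    from fastchat.model.model_chatglm import recover_message_list

    prompt = "<|system|>\nYou are a helpful assistant.<|user|>\nHello<|assistant|>"
    messages = recover_message_list(prompt)
    # -> [{"role": "system", "content": "You are a helpful assistant."},
    #     {"role": "user", "content": "Hello"}]
    # The trailing "<|assistant|>" marker is deliberately not turned into a
    # message: the last recovered entry is passed to build_chat_input as the
    # query, and the earlier entries become its history.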