Skip to content

Commit

Permalink
safe format prompt variables in strings with JSON (#15734)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Sep 22, 2024
1 parent ac4e7e4 commit 7d9bd0f
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 101 deletions.
42 changes: 4 additions & 38 deletions llama-index-core/llama_index/core/llms/structured_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,40 +39,6 @@
OutputKeys,
QueryComponent,
)
import re


def _escape_braces(text: str) -> str:
"""
Escape braces in text.
Only captures template variables, skips already escaped braces.
"""

def replace(match: re.Match[str]) -> str:
if match.group(0).startswith("{{") and match.group(0).endswith("}}"):
return match.group(0) # Already escaped, return as is
return "{{" + match.group(1) + "}}"

pattern = r"(?<!\{)\{([^{}]+?)\}(?!\})"
return re.sub(pattern, replace, text)


def _escape_json(messages: Sequence[ChatMessage]) -> Sequence[ChatMessage]:
    """Return ``messages`` with brace-escaped string content.

    Each message whose ``content`` is a string is replaced by a copy whose
    content has been run through ``_escape_braces``; any message with
    non-string content is passed through unchanged.
    """

    def _escape_one(msg: ChatMessage) -> ChatMessage:
        # Non-string content (e.g. None or structured blocks) cannot be
        # brace-escaped, so leave the message as-is.
        if not isinstance(msg.content, str):
            return msg
        return ChatMessage(
            role=msg.role,
            content=_escape_braces(msg.content),
            additional_kwargs=msg.additional_kwargs,
        )

    return [_escape_one(message) for message in messages]


class StructuredLLM(LLM):
Expand Down Expand Up @@ -104,7 +70,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
# make this work with our FunctionCallingProgram, even though
# the messages don't technically have any variables (they are already formatted)

chat_prompt = ChatPromptTemplate(message_templates=_escape_json(messages))
chat_prompt = ChatPromptTemplate(message_templates=messages)

output = self.llm.structured_predict(
output_cls=self.output_cls, prompt=chat_prompt, llm_kwargs=kwargs
Expand All @@ -120,7 +86,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
chat_prompt = ChatPromptTemplate(message_templates=_escape_json(messages))
chat_prompt = ChatPromptTemplate(message_templates=messages)

stream_output = self.llm.stream_structured_predict(
output_cls=self.output_cls, prompt=chat_prompt, llm_kwargs=kwargs
Expand Down Expand Up @@ -158,7 +124,7 @@ async def achat(
# make this work with our FunctionCallingProgram, even though
# the messages don't technically have any variables (they are already formatted)

chat_prompt = ChatPromptTemplate(message_templates=_escape_json(messages))
chat_prompt = ChatPromptTemplate(message_templates=messages)

output = await self.llm.astructured_predict(
output_cls=self.output_cls, prompt=chat_prompt, llm_kwargs=kwargs
Expand All @@ -179,7 +145,7 @@ async def astream_chat(
"""Async stream chat endpoint for LLM."""

async def gen() -> ChatResponseAsyncGen:
chat_prompt = ChatPromptTemplate(message_templates=_escape_json(messages))
chat_prompt = ChatPromptTemplate(message_templates=messages)

stream_output = await self.llm.astream_structured_predict(
output_cls=self.output_cls, prompt=chat_prompt, llm_kwargs=kwargs
Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
prompt_to_messages,
)
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.prompts.utils import get_template_vars
from llama_index.core.prompts.utils import get_template_vars, format_string
from llama_index.core.types import BaseOutputParser


Expand Down Expand Up @@ -205,7 +205,7 @@ def format(
}

mapped_all_kwargs = self._map_all_vars(all_kwargs)
prompt = self.template.format(**mapped_all_kwargs)
prompt = format_string(self.template, **mapped_all_kwargs)

if self.output_parser is not None:
prompt = self.output_parser.format(prompt)
Expand Down Expand Up @@ -313,7 +313,7 @@ def format_messages(
content_template = message_template.content or ""

# if there's mappings specified, make sure those are used
content = content_template.format(**relevant_kwargs)
content = format_string(content_template, **relevant_kwargs)

message: ChatMessage = message_template.model_copy()
message.content = content
Expand Down
31 changes: 27 additions & 4 deletions llama-index-core/llama_index/core/prompts/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from string import Formatter
from typing import List
from typing import Dict, List, Optional
import re

from llama_index.core.base.llms.base import BaseLLM


class SafeFormatter:
    """Safe string formatter that does not raise KeyError if key is missing."""

    # Matches one {key} span whose body contains no nested braces.
    _FIELD_PATTERN = r"\{([^{}]+)\}"

    def __init__(self, format_dict: Optional[Dict[str, str]] = None):
        # A missing/empty mapping means every field is left verbatim.
        self.format_dict = format_dict or {}

    def format(self, format_string: str) -> str:
        """Substitute known keys; leave unknown ``{key}`` spans untouched."""

        def _sub(match: re.Match) -> str:
            value = self.format_dict.get(match.group(1), match.group(0))
            return str(value)

        return re.sub(self._FIELD_PATTERN, _sub, format_string)

    def parse(self, format_string: str) -> List[str]:
        """Return the body of every ``{key}`` span in *format_string*."""
        return re.findall(self._FIELD_PATTERN, format_string)

    def _replace_match(self, match: re.Match) -> str:
        # Fall back to the full match text when the key is unknown, so
        # unresolved template fields survive formatting verbatim.
        key = match.group(1)
        return str(self.format_dict.get(key, match.group(0)))


def format_string(string_to_format: str, **kwargs: str) -> str:
    """Format *string_to_format*, leaving unknown ``{...}`` fields intact."""
    # Delegate to SafeFormatter so missing keys never raise KeyError.
    return SafeFormatter(format_dict=kwargs).format(string_to_format)


def get_template_vars(template_str: str) -> List[str]:
"""Get template variables from a template string."""
variables = []
formatter = Formatter()
formatter = SafeFormatter()

for _, variable_name, _, _ in formatter.parse(template_str):
for variable_name in formatter.parse(template_str):
if variable_name:
variables.append(variable_name)

Expand Down
56 changes: 0 additions & 56 deletions llama-index-core/tests/llms/test_structured_llm.py

This file was deleted.

18 changes: 18 additions & 0 deletions llama-index-core/tests/prompts/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,21 @@ def _format_prompt_key1(**kwargs: Any) -> str:
"user: hello tmp1-tmp2 tmp2\n"
"assistant: "
)


def test_template_with_json() -> None:
    """Ensure literal JSON braces in a template survive formatting."""
    prompt_txt = 'hello {text} {foo} {"bar": "baz"}'
    prompt = PromptTemplate(prompt_txt)
    expected = 'hello world foo2 {"bar": "baz"}'

    assert prompt.format(foo="foo2", text="world") == expected
    assert prompt.format_messages(foo="foo2", text="world") == [
        ChatMessage(content=expected, role=MessageRole.USER)
    ]

    # An unsupplied variable is left in place rather than raising.
    test_case_2 = PromptTemplate("test {message} {test}")
    assert test_case_2.format(message="message") == "test message {test}"

    # Doubled braces collapse to single literal braces after formatting.
    test_case_3 = PromptTemplate("test {{message}} {{test}}")
    assert test_case_3.format(message="message", test="test") == "test {message} {test}"

0 comments on commit 7d9bd0f

Please sign in to comment.