Skip to content

Commit

Permalink
✨feat: supports to check the content safety of LLM output #474
Browse files Browse the repository at this point in the history
  • Loading branch information
Soulter committed Feb 11, 2025
1 parent 9fa00af commit 43cd34d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
6 changes: 6 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
}
},
"content_safety": {
"also_use_in_response": False,
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
},
Expand Down Expand Up @@ -310,6 +311,11 @@
"description": "内容安全",
"type": "object",
"items": {
"also_use_in_response": {
"description": "对大模型响应安全审核",
"type": "bool",
"hint": "启用后,大模型的响应也会通过内容安全审核。",
},
"baidu_aip": {
"description": "百度内容审核配置",
"type": "object",
Expand Down
6 changes: 4 additions & 2 deletions astrbot/core/pipeline/content_safety_check/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ async def initialize(self, ctx: PipelineContext):
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)

async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
ok, info = self.strategy_selector.check(event.get_message_str())
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
if not ok:
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
yield
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
Expand Down
24 changes: 21 additions & 3 deletions astrbot/core/pipeline/result_decorate/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
Expand All @@ -12,7 +12,7 @@
from astrbot.core.star.star_handler import star_handlers_registry, EventType

@register_stage
class ResultDecorateStage:
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
Expand All @@ -30,12 +30,30 @@ async def initialize(self, ctx: PipelineContext):
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']


# exception
self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response']
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage


async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return

# 回复时检查内容安全
if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result():
text = ""
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield

handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
await handler.handler(event)
Expand Down

0 comments on commit 43cd34d

Please sign in to comment.