Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

支持插件注册消息中间件 #202

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions addons/plugins/helloworld/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
flag_not_support = False
try:
from util.plugin_dev.api.v1.bot import Context, AstrMessageEvent, CommandResult
from util.plugin_dev.api.v1.config import *
from util.plugin_dev.api.v1 import (
Context,
CommandResult,
AstrMessageEvent,
Middleware,
)
except ImportError:
flag_not_support = True
print("导入接口失败。请升级到 AstrBot 最新版本。")
Expand All @@ -21,12 +25,12 @@ class HelloWorldPlugin:
def __init__(self, context: Context) -> None:
self.context = context
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)

"""
指令处理函数。

- 需要接收两个参数:message: AstrMessageEvent, context: Context
- 返回 CommandResult 对象
"""
def helloworld(self, message: AstrMessageEvent, context: Context):
async def helloworld(self, message: AstrMessageEvent, context: Context):
return CommandResult().message("Hello, World!")
12 changes: 10 additions & 2 deletions astrbot/message/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,20 @@ async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None)
is_command_call=True,
use_t2i=cmd_res.is_use_t2i
)

# next is the LLM part

# middlewares
for middleware in self.context.middlewares:
try:
logger.info(f"执行中间件 {middleware.origin}/{middleware.name}...")
await middleware.func(message, self.context)
except BaseException as e:
logger.error(f"中间件 {middleware.origin}/{middleware.name} 处理消息时发生异常:{e},跳过。")
logger.error(traceback.format_exc())

if message.only_command:
return

# next is the LLM part
# check if the message is a llm-wake-up command
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
Expand Down
6 changes: 3 additions & 3 deletions model/command/internal_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,18 @@ async def help(self, message: AstrMessageEvent, context: Context):
except BaseException as e:
logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.")

msg = "# Help Center\n## 指令列表\n"
msg = "# 帮助中心\n## 指令\n"
for key, value in self.manager.commands_handler.items():
if value.plugin_metadata:
msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n"
else: msg += f"- `{key}`: {value.description}\n"
# plugins
if context.cached_plugins != None:
if context.cached_plugins:
plugin_list_info = ""
for plugin in context.cached_plugins:
plugin_list_info += f"- `{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n"
if plugin_list_info.strip() != "":
msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n"
msg += "\n## 插件\n> 使用plugin v 插件名 查看插件帮助\n"
msg += plugin_list_info
msg += notice

Expand Down
8 changes: 8 additions & 0 deletions type/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass

@dataclass
class Middleware():
name: str = ""
description: str = ""
origin: str = "" # 注册来源
func: callable = None
10 changes: 10 additions & 0 deletions type/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.middleware import Middleware
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self):
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
self.middlewares: List[Middleware] = []

self.command_manager = None
self.running = True
Expand Down Expand Up @@ -115,6 +117,14 @@ def unregister_llm_tool(self, tool_name: str):
删除一个函数调用工具。
'''
self.message_handler.llm_tools.remove_func(tool_name)

def register_middleware(self, middleware: Middleware):
'''
注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。

在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。
'''
self.middlewares.append(middleware)

def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
Expand Down
7 changes: 7 additions & 0 deletions util/plugin_dev/api/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .bot import *
from .config import *
from .llm import *
from .message import *
from .platform import *
from .register import *
from .types import *
2 changes: 1 addition & 1 deletion util/plugin_dev/api/v1/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
'''

from type.plugin import PluginType

from type.middleware import Middleware
from nakuru.entities.components import Image, Plain, At, Node, BaseMessageComponent
Loading