Skip to content

Commit

Permalink
feat: 支持自定义文转图服务地址
Browse files Browse the repository at this point in the history
  • Loading branch information
Soulter committed Sep 22, 2024
1 parent 90815b1 commit 353b6ed
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 2 deletions.
6 changes: 6 additions & 0 deletions astrbot/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def __init__(self) -> None:
logger.info("未使用代理。")

self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'

# set t2i endpoint
if self.context.config_helper.t2i_endpoint:
self.context.image_renderer.set_network_endpoint(
self.context.config_helper.t2i_endpoint
)

async def run(self):
self.command_manager = CommandManager()
Expand Down
2 changes: 2 additions & 0 deletions type/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
"password": "",
},
"log_level": "INFO",
"t2i_endpoint": "",
}

# 这个是用于迁移旧版本配置文件的映射表
Expand Down Expand Up @@ -352,4 +353,5 @@
}
},
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
}
2 changes: 2 additions & 0 deletions util/cmd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class AstrBotConfig():
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
log_level: str = "INFO"
t2i_endpoint: str = ""

def __init__(self) -> None:
self.init_configs()
Expand Down Expand Up @@ -176,6 +177,7 @@ def load_from_dict(self, data: Dict):
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", [])
self.log_level=data.get("log_level", "INFO")
self.t2i_endpoint=data.get("t2i_endpoint", "")

def migrate_config_1_2(self, old: dict) -> dict:
'''将配置文件从版本 1 迁移至版本 2'''
Expand Down
10 changes: 8 additions & 2 deletions util/t2i/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
logger: Logger = LogManager.GetLogger(log_name='astrbot')

class TextToImageRenderer:
def __init__(self):
self.network_strategy = NetworkRenderStrategy()
def __init__(self, endpoint_url: str = None):
self.network_strategy = NetworkRenderStrategy(endpoint_url)
self.local_strategy = LocalRenderStrategy()
self.context = RenderContext(self.network_strategy)

def set_network_endpoint(self, endpoint_url: str):
'''设置 t2i 的网络端点。
'''
logger.info("文本转图像服务接口: " + endpoint_url)
self.network_strategy.set_endpoint(endpoint_url)

async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
@param tmpl_str: HTML Jinja2 模板。
Expand Down
7 changes: 7 additions & 0 deletions util/t2i/strategies/network_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@
class NetworkRenderStrategy(RenderStrategy):
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
super().__init__()
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")

def set_endpoint(self, base_url: str):
if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
self.BASE_RENDER_URL = base_url

async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
'''使用自定义文转图模板'''
post_data = {
Expand Down

0 comments on commit 353b6ed

Please sign in to comment.