Merge main into pr-1482

This commit is contained in:
Re-bin
2026-03-07 15:33:24 +00:00
40 changed files with 2355 additions and 325 deletions

View File

@@ -10,6 +10,7 @@ from typing import Any
from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import detect_image_mime
class ContextBuilder:
@@ -136,10 +137,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
images = []
for path in media:
p = Path(path)
mime, _ = mimetypes.guess_type(path)
if not p.is_file() or not mime or not mime.startswith("image/"):
if not p.is_file():
continue
b64 = base64.b64encode(p.read_bytes()).decode()
raw = p.read_bytes()
# Detect real MIME type from magic bytes; fallback to filename guess
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
if not mime or not mime.startswith("image/"):
continue
b64 = base64.b64encode(raw).decode()
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
if not images:

View File

@@ -202,9 +202,9 @@ class AgentLoop:
if response.has_tool_calls:
if on_progress:
clean = self._strip_think(response.content)
if clean:
await on_progress(clean)
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [

View File

@@ -128,6 +128,13 @@ class MemoryStore:
# Some providers return arguments as a JSON string instead of dict
if isinstance(args, str):
args = json.loads(args)
# Some providers return arguments as a list (handle edge case)
if isinstance(args, list):
if args and isinstance(args[0], dict):
args = args[0]
else:
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
return False
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False

View File

@@ -52,8 +52,79 @@ class Tool(ABC):
"""
pass
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""Apply safe schema-driven casts before validation."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
return params
return self._cast_object(params, schema)
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
"""Cast an object (dict) according to schema."""
if not isinstance(obj, dict):
return obj
props = schema.get("properties", {})
result = {}
for key, value in obj.items():
if key in props:
result[key] = self._cast_value(value, props[key])
else:
result[key] = value
return result
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
"""Cast a single value according to schema."""
target_type = schema.get("type")
if target_type == "boolean" and isinstance(val, bool):
return val
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
return val
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
expected = self._TYPE_MAP[target_type]
if isinstance(val, expected):
return val
if target_type == "integer" and isinstance(val, str):
try:
return int(val)
except ValueError:
return val
if target_type == "number" and isinstance(val, str):
try:
return float(val)
except ValueError:
return val
if target_type == "string":
return val if val is None else str(val)
if target_type == "boolean" and isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
if val_lower in ("false", "0", "no"):
return False
return val
if target_type == "array" and isinstance(val, list):
item_schema = schema.get("items")
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
if target_type == "object" and isinstance(val, dict):
return self._cast_object(val, schema)
return val
def validate_params(self, params: dict[str, Any]) -> list[str]:
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
if not isinstance(params, dict):
return [f"parameters must be an object, got {type(params).__name__}"]
schema = self.parameters or {}
if schema.get("type", "object") != "object":
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
@@ -61,7 +132,13 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
return [f"{label} should be integer"]
if t == "number" and (
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
):
return [f"{label} should be number"]
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"]
errors = []

View File

@@ -122,7 +122,10 @@ class CronTool(Tool):
elif at:
from datetime import datetime
dt = datetime.fromisoformat(at)
try:
dt = datetime.fromisoformat(at)
except ValueError:
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
at_ms = int(dt.timestamp() * 1000)
schedule = CronSchedule(kind="at", at_ms=at_ms)
delete_after = True

View File

@@ -26,6 +26,8 @@ def _resolve_path(
class ReadFileTool(Tool):
"""Tool to read file contents."""
_MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir
@@ -54,7 +56,16 @@ class ReadFileTool(Tool):
if not file_path.is_file():
return f"Error: Not a file: {path}"
size = file_path.stat().st_size
if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes)
return (
f"Error: File too large ({size:,} bytes). "
f"Use exec tool with head/tail/grep to read portions."
)
content = file_path.read_text(encoding="utf-8")
if len(content) > self._MAX_CHARS:
return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
return content
except PermissionError as e:
return f"Error: {e}"

View File

@@ -58,17 +58,48 @@ async def connect_mcp_servers(
) -> None:
"""Connect to configured MCP servers and register their tools."""
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items():
try:
if cfg.command:
transport_type = cfg.type
if not transport_type:
if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
if transport_type == "stdio":
params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None
)
read, write = await stack.enter_async_context(stdio_client(params))
elif cfg.url:
from mcp.client.streamable_http import streamable_http_client
elif transport_type == "sse":
def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
merged_headers = {**(cfg.headers or {}), **(headers or {})}
return httpx.AsyncClient(
headers=merged_headers or None,
follow_redirects=True,
timeout=timeout,
auth=auth,
)
read, write = await stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
@@ -82,7 +113,7 @@ async def connect_mcp_servers(
streamable_http_client(cfg.url, http_client=http_client)
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue
session = await stack.enter_async_context(ClientSession(read, write))

View File

@@ -96,7 +96,7 @@ class MessageTool(Tool):
media=media or [],
metadata={
"message_id": message_id,
}
},
)
try:

View File

@@ -44,6 +44,10 @@ class ToolRegistry:
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT

View File

@@ -13,34 +13,13 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import DiscordConfig
from nanobot.utils.helpers import split_message
DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket."""
@@ -54,6 +33,7 @@ class DiscordChannel(BaseChannel):
self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None
self._bot_user_id: str | None = None
async def start(self) -> None:
"""Start the Discord gateway connection."""
@@ -95,7 +75,7 @@ class DiscordChannel(BaseChannel):
self._http = None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API."""
"""Send a message through Discord REST API, including file attachments."""
if not self._http:
logger.warning("Discord HTTP client not initialized")
return
@@ -104,15 +84,31 @@ class DiscordChannel(BaseChannel):
headers = {"Authorization": f"Bot {self.config.token}"}
try:
chunks = _split_message(msg.content or "")
sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# Only set reply reference on the first chunk
if i == 0 and msg.reply_to:
# Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
@@ -143,6 +139,54 @@ class DiscordChannel(BaseChannel):
await asyncio.sleep(1)
return False
async def _send_file(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws:
@@ -170,6 +214,10 @@ class DiscordChannel(BaseChannel):
await self._identify()
elif op == 0 and event_type == "READY":
logger.info("Discord gateway READY")
# Capture bot user ID for mention detection
user_data = payload.get("user") or {}
self._bot_user_id = user_data.get("id")
logger.info("Discord bot connected as user {}", self._bot_user_id)
elif op == 0 and event_type == "MESSAGE_CREATE":
await self._handle_message_create(payload)
elif op == 7:
@@ -226,6 +274,7 @@ class DiscordChannel(BaseChannel):
sender_id = str(author.get("id", ""))
channel_id = str(payload.get("channel_id", ""))
content = payload.get("content") or ""
guild_id = payload.get("guild_id")
if not sender_id or not channel_id:
return
@@ -233,6 +282,11 @@ class DiscordChannel(BaseChannel):
if not self.is_allowed(sender_id):
return
# Check group channel policy (DMs always respond if is_allowed passes)
if guild_id is not None:
if not self._should_respond_in_group(payload, content):
return
content_parts = [content] if content else []
media_paths: list[str] = []
media_dir = Path.home() / ".nanobot" / "media"
@@ -269,11 +323,32 @@ class DiscordChannel(BaseChannel):
media=media_paths,
metadata={
"message_id": str(payload.get("id", "")),
"guild_id": payload.get("guild_id"),
"guild_id": guild_id,
"reply_to": reply_to,
},
)
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
"""Check if bot should respond in a group channel based on policy."""
if self.config.group_policy == "open":
return True
if self.config.group_policy == "mention":
# Check if bot was mentioned in the message
if self._bot_user_id:
# Check mentions array
mentions = payload.get("mentions") or []
for mention in mentions:
if str(mention.get("id")) == self._bot_user_id:
return True
# Also check content for mention format <@USER_ID>
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
return False
return True
async def _start_typing(self, channel_id: str) -> None:
"""Start periodic typing indicator for a channel."""
await self._stop_typing(channel_id)

View File

@@ -16,26 +16,9 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import FeishuConfig
try:
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
CreateFileRequest,
CreateFileRequestBody,
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
Emoji,
GetMessageResourceRequest,
P2ImMessageReceiveV1,
)
FEISHU_AVAILABLE = True
except ImportError:
FEISHU_AVAILABLE = False
lark = None
Emoji = None
import importlib.util
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
# Message type display mapping
MSG_TYPE_MAP = {
@@ -261,15 +244,22 @@ class FeishuChannel(BaseChannel):
name = "feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus):
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
super().__init__(config, bus)
self.config: FeishuConfig = config
self.groq_api_key = groq_api_key
self._client: Any = None
self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None
@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
"""Register an event handler only when the SDK supports it."""
method = getattr(builder, method_name, None)
return method(handler) if callable(method) else builder
async def start(self) -> None:
"""Start the Feishu bot with WebSocket long connection."""
if not FEISHU_AVAILABLE:
@@ -280,6 +270,7 @@ class FeishuChannel(BaseChannel):
logger.error("Feishu app_id and app_secret not configured")
return
import lark_oapi as lark
self._running = True
self._loop = asyncio.get_running_loop()
@@ -289,14 +280,24 @@ class FeishuChannel(BaseChannel):
.app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \
.build()
# Create event handler (only register message receive, ignore other events)
event_handler = lark.EventDispatcherHandler.builder(
builder = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "",
self.config.verification_token or "",
).register_p2_im_message_receive_v1(
self._on_message_sync
).build()
)
builder = self._register_optional_event(
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
)
builder = self._register_optional_event(
builder, "register_p2_im_message_message_read_v1", self._on_message_read
)
builder = self._register_optional_event(
builder,
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
self._on_bot_p2p_chat_entered,
)
event_handler = builder.build()
# Create WebSocket client for long connection
self._ws_client = lark.ws.Client(
@@ -306,16 +307,28 @@ class FeishuChannel(BaseChannel):
log_level=lark.LogLevel.INFO
)
# Start WebSocket client in a separate thread with reconnect loop
# Start WebSocket client in a separate thread with reconnect loop.
# A dedicated event loop is created for this thread so that lark_oapi's
# module-level `loop = asyncio.get_event_loop()` picks up an idle loop
# instead of the already-running main asyncio loop, which would cause
# "This event loop is already running" errors.
def run_ws():
while self._running:
try:
self._ws_client.start()
except Exception as e:
logger.warning("Feishu WebSocket error: {}", e)
if self._running:
import time
time.sleep(5)
import time
import lark_oapi.ws.client as _lark_ws_client
ws_loop = asyncio.new_event_loop()
asyncio.set_event_loop(ws_loop)
# Patch the module-level loop used by lark's ws Client.start()
_lark_ws_client.loop = ws_loop
try:
while self._running:
try:
self._ws_client.start()
except Exception as e:
logger.warning("Feishu WebSocket error: {}", e)
if self._running:
time.sleep(5)
finally:
ws_loop.close()
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
self._ws_thread.start()
@@ -340,6 +353,7 @@ class FeishuChannel(BaseChannel):
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
try:
request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \
@@ -364,7 +378,7 @@ class FeishuChannel(BaseChannel):
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
"""
if not self._client or not Emoji:
if not self._client:
return
loop = asyncio.get_running_loop()
@@ -413,6 +427,34 @@ class FeishuChannel(BaseChannel):
elements.extend(self._split_headings(remaining))
return elements or [{"tag": "markdown", "content": content}]
@staticmethod
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
"""Split card elements into groups with at most *max_tables* table elements each.
Feishu cards have a hard limit of one table per card (API error 11310).
When the rendered content contains multiple markdown tables each table is
placed in a separate card message so every table reaches the user.
"""
if not elements:
return [[]]
groups: list[list[dict]] = []
current: list[dict] = []
table_count = 0
for el in elements:
if el.get("tag") == "table":
if table_count >= max_tables:
if current:
groups.append(current)
current = []
table_count = 0
current.append(el)
table_count += 1
else:
current.append(el)
if current:
groups.append(current)
return groups or [[]]
def _split_headings(self, content: str) -> list[dict]:
"""Split content by headings, converting headings to div elements."""
protected = content
@@ -447,8 +489,124 @@ class FeishuChannel(BaseChannel):
return elements or [{"tag": "markdown", "content": content}]
# ── Smart format detection ──────────────────────────────────────────
# Patterns that indicate "complex" markdown needing card rendering
_COMPLEX_MD_RE = re.compile(
r"```" # fenced code block
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
r"|^#{1,6}\s+" # headings
, re.MULTILINE,
)
# Simple markdown patterns (bold, italic, strikethrough)
_SIMPLE_MD_RE = re.compile(
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
r"|~~.+?~~" # ~~strikethrough~~
, re.DOTALL,
)
# Markdown link: [text](url)
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
# Unordered list items
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
# Ordered list items
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
# Max length for plain text format
_TEXT_MAX_LEN = 200
# Max length for post (rich text) format; beyond this, use card
_POST_MAX_LEN = 2000
@classmethod
def _detect_msg_format(cls, content: str) -> str:
"""Determine the optimal Feishu message format for *content*.
Returns one of:
- ``"text"`` plain text, short and no markdown
- ``"post"`` rich text (links only, moderate length)
- ``"interactive"`` card with full markdown rendering
"""
stripped = content.strip()
# Complex markdown (code blocks, tables, headings) → always card
if cls._COMPLEX_MD_RE.search(stripped):
return "interactive"
# Long content → card (better readability with card layout)
if len(stripped) > cls._POST_MAX_LEN:
return "interactive"
# Has bold/italic/strikethrough → card (post format can't render these)
if cls._SIMPLE_MD_RE.search(stripped):
return "interactive"
# Has list items → card (post format can't render list bullets well)
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
return "interactive"
# Has links → post format (supports <a> tags)
if cls._MD_LINK_RE.search(stripped):
return "post"
# Short plain text → text format
if len(stripped) <= cls._TEXT_MAX_LEN:
return "text"
# Medium plain text without any formatting → post format
return "post"
@classmethod
def _markdown_to_post(cls, content: str) -> str:
"""Convert markdown content to Feishu post message JSON.
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
Each line becomes a paragraph (row) in the post body.
"""
lines = content.strip().split("\n")
paragraphs: list[list[dict]] = []
for line in lines:
elements: list[dict] = []
last_end = 0
for m in cls._MD_LINK_RE.finditer(line):
# Text before this link
before = line[last_end:m.start()]
if before:
elements.append({"tag": "text", "text": before})
elements.append({
"tag": "a",
"text": m.group(1),
"href": m.group(2),
})
last_end = m.end()
# Remaining text after last link
remaining = line[last_end:]
if remaining:
elements.append({"tag": "text", "text": remaining})
# Empty line → empty paragraph for spacing
if not elements:
elements.append({"tag": "text", "text": ""})
paragraphs.append(elements)
post_body = {
"zh_cn": {
"content": paragraphs,
}
}
return json.dumps(post_body, ensure_ascii=False)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
_AUDIO_EXTS = {".opus"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
@@ -456,6 +614,7 @@ class FeishuChannel(BaseChannel):
def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key."""
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
try:
with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \
@@ -479,6 +638,7 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key."""
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path)
@@ -506,6 +666,7 @@ class FeishuChannel(BaseChannel):
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
try:
request = GetMessageResourceRequest.builder() \
.message_id(message_id) \
@@ -530,6 +691,13 @@ class FeishuChannel(BaseChannel):
self, message_id: str, file_key: str, resource_type: str = "file"
) -> tuple[bytes | None, str | None]:
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
# Feishu API only accepts 'image' or 'file' as type parameter
# Convert 'audio' to 'file' for API compatibility
if resource_type == "audio":
resource_type = "file"
try:
request = (
GetMessageResourceRequest.builder()
@@ -598,6 +766,7 @@ class FeishuChannel(BaseChannel):
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
"""Send a single message (text/image/file/interactive) synchronously."""
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
try:
request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \
@@ -646,23 +815,50 @@ class FeishuChannel(BaseChannel):
else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
if key:
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
# Use msg_type "media" for audio/video so users can play inline;
# "file" for everything else (documents, archives, etc.)
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
media_type = "media"
else:
media_type = "file"
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
)
if msg.content and msg.content.strip():
card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)}
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
)
fmt = self._detect_msg_format(msg.content)
if fmt == "text":
# Short plain text send as simple text message
text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "text", text_body,
)
elif fmt == "post":
# Medium content with links send as rich-text post
post_body = self._markdown_to_post(msg.content)
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "post", post_body,
)
else:
# Complex / long content send as interactive card
elements = self._build_card_elements(msg.content)
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
)
except Exception as e:
logger.error("Error sending Feishu message: {}", e)
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
def _on_message_sync(self, data: Any) -> None:
"""
Sync handler for incoming messages (called from WebSocket thread).
Schedules async handling in the main event loop.
@@ -670,7 +866,7 @@ class FeishuChannel(BaseChannel):
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
async def _on_message(self, data: Any) -> None:
"""Handle incoming message from Feishu."""
try:
event = data.event
@@ -730,6 +926,18 @@ class FeishuChannel(BaseChannel):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
if file_path:
media_paths.append(file_path)
# Transcribe audio using Groq Whisper
if msg_type == "audio" and file_path and self.groq_api_key:
try:
from nanobot.providers.transcription import GroqTranscriptionProvider
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path)
if transcription:
content_text = f"[transcription: {transcription}]"
except Exception as e:
logger.warning("Failed to transcribe audio: {}", e)
content_parts.append(content_text)
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
@@ -762,3 +970,16 @@ class FeishuChannel(BaseChannel):
except Exception as e:
logger.error("Error processing Feishu message: {}", e)
def _on_reaction_created(self, data: Any) -> None:
"""Ignore reaction events so they do not generate SDK noise."""
pass
def _on_message_read(self, data: Any) -> None:
"""Ignore read events so they do not generate SDK noise."""
pass
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
"""Ignore p2p-enter events when a user opens a bot chat."""
logger.debug("Bot entered p2p chat (user opened chat window)")
pass

View File

@@ -74,7 +74,8 @@ class ChannelManager:
try:
from nanobot.channels.feishu import FeishuChannel
self.channels["feishu"] = FeishuChannel(
self.config.channels.feishu, self.bus
self.config.channels.feishu, self.bus,
groq_api_key=self.config.providers.groq.api_key,
)
logger.info("Feishu channel enabled")
except ImportError as e:

View File

@@ -56,6 +56,7 @@ class QQChannel(BaseChannel):
self.config: QQConfig = config
self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000)
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
async def start(self) -> None:
"""Start the QQ bot."""
@@ -102,11 +103,13 @@ class QQChannel(BaseChannel):
return
try:
msg_id = msg.metadata.get("message_id")
self._msg_seq += 1 # 递增序列号
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq, # 添加序列号避免去重
)
except Exception as e:
logger.error("Error sending QQ message: {}", e)
@@ -133,3 +136,4 @@ class QQChannel(BaseChannel):
)
except Exception:
logger.exception("Error handling QQ message")

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio
import re
import time
import unicodedata
from loguru import logger
from telegram import BotCommand, ReplyParameters, Update
@@ -14,6 +16,50 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
def _strip_md(s: str) -> str:
"""Strip markdown inline formatting from text."""
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
s = re.sub(r'__(.+?)__', r'\1', s)
s = re.sub(r'~~(.+?)~~', r'\1', s)
s = re.sub(r'`([^`]+)`', r'\1', s)
return s.strip()
def _render_table_box(table_lines: list[str]) -> str:
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
def dw(s: str) -> int:
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
rows: list[list[str]] = []
has_sep = False
for line in table_lines:
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
has_sep = True
continue
rows.append(cells)
if not rows or not has_sep:
return '\n'.join(table_lines)
ncols = max(len(r) for r in rows)
for r in rows:
r.extend([''] * (ncols - len(r)))
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
def dr(cells: list[str]) -> str:
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
out = [dr(rows[0])]
out.append(' '.join('' * w for w in widths))
for row in rows[1:]:
out.append(dr(row))
return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str:
@@ -31,6 +77,27 @@ def _markdown_to_telegram_html(text: str) -> str:
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
lines = text.split('\n')
rebuilt: list[str] = []
li = 0
while li < len(lines):
if re.match(r'^\s*\|.+\|', lines[li]):
tbl: list[str] = []
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
tbl.append(lines[li])
li += 1
box = _render_table_box(tbl)
if box != '\n'.join(tbl):
code_blocks.append(box)
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
else:
rebuilt.extend(tbl)
else:
rebuilt.append(lines[li])
li += 1
text = '\n'.join(rebuilt)
# 2. Extract and protect inline code
inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str:
@@ -79,26 +146,6 @@ def _markdown_to_telegram_html(text: str) -> str:
return text
def _split_message(content: str, max_len: int = 4000) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos == -1:
pos = cut.rfind(' ')
if pos == -1:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class TelegramChannel(BaseChannel):
"""
Telegram channel using long polling.
@@ -130,6 +177,7 @@ class TelegramChannel(BaseChannel):
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
self._message_threads: dict[tuple[str, int], int] = {}
async def start(self) -> None:
"""Start the Telegram bot with long polling."""
@@ -140,16 +188,21 @@ class TelegramChannel(BaseChannel):
self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
req = HTTPXRequest(
connection_pool_size=16,
pool_timeout=5.0,
connect_timeout=30.0,
read_timeout=30.0,
proxy=self.config.proxy if self.config.proxy else None,
)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
if self.config.proxy:
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build()
self._app.add_error_handler(self._on_error)
# Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents
@@ -225,21 +278,25 @@ class TelegramChannel(BaseChannel):
logger.warning("Telegram bot not running")
return
self._stop_typing(msg.chat_id)
# Only stop typing indicator for final responses
if not msg.metadata.get("_progress", False):
self._stop_typing(msg.chat_id)
try:
chat_id = int(msg.chat_id)
except ValueError:
logger.error("Invalid chat_id: {}", msg.chat_id)
return
reply_to_message_id = msg.metadata.get("message_id")
message_thread_id = msg.metadata.get("message_thread_id")
if message_thread_id is None and reply_to_message_id is not None:
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
thread_kwargs = {}
if message_thread_id is not None:
thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None
if self.config.reply_to_message:
reply_to_message_id = msg.metadata.get("message_id")
if reply_to_message_id:
reply_params = ReplyParameters(
message_id=reply_to_message_id,
@@ -275,27 +332,65 @@ class TelegramChannel(BaseChannel):
# Send text content
if msg.content and msg.content != "[empty message]":
for chunk in _split_message(msg.content):
try:
html = _markdown_to_telegram_html(chunk)
await self._app.bot.send_message(
chat_id=chat_id,
text=html,
parse_mode="HTML",
reply_parameters=reply_params,
**thread_kwargs,
)
except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
await self._app.bot.send_message(
chat_id=chat_id,
text=chunk,
reply_parameters=reply_params,
**thread_kwargs,
)
except Exception as e2:
logger.error("Error sending Telegram message: {}", e2)
is_progress = msg.metadata.get("_progress", False)
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
# Final response: simulate streaming via draft, then persist
if not is_progress:
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
else:
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _send_text(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Send a plain text message with HTML fallback."""
try:
html = _markdown_to_telegram_html(text)
await self._app.bot.send_message(
chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
await self._app.bot.send_message(
chat_id=chat_id,
text=text,
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e2:
logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31)
try:
step = max(len(text) // 8, 40)
for i in range(step, len(text), step):
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text[:i],
)
await asyncio.sleep(0.04)
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text,
)
await asyncio.sleep(0.15)
except Exception:
pass
await self._send_text(chat_id, text, reply_params, thread_kwargs)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command."""
@@ -347,12 +442,23 @@ class TelegramChannel(BaseChannel):
"is_forum": bool(getattr(message.chat, "is_forum", False)),
}
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
key = (str(message.chat_id), message.message_id)
self._message_threads[key] = message_thread_id
if len(self._message_threads) > 1000:
self._message_threads.pop(next(iter(self._message_threads)))
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user:
return
message = update.message
user = update.effective_user
self._remember_thread_context(message)
await self._handle_message(
sender_id=self._sender_id(user),
chat_id=str(message.chat_id),
@@ -370,6 +476,7 @@ class TelegramChannel(BaseChannel):
user = update.effective_user
chat_id = message.chat_id
sender_id = self._sender_id(user)
self._remember_thread_context(message)
# Store chat_id for replies
self._chat_ids[sender_id] = chat_id

View File

@@ -2,6 +2,7 @@
import asyncio
import json
import mimetypes
from collections import OrderedDict
from loguru import logger
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
mime, _ = mimetypes.guess_type(p)
media_type = "image" if mime and mime.startswith("image/") else "file"
media_tag = f"[{media_type}: {p}]"
content = f"{content}\n{media_tag}" if content else media_tag
await self._handle_message(
sender_id=sender_id,
chat_id=sender, # Use full LID for replies
content=content,
media=media_paths,
metadata={
"message_id": message_id,
"timestamp": data.get("timestamp"),

View File

@@ -7,6 +7,18 @@ import signal
import sys
from pathlib import Path
# Force UTF-8 encoding for Windows console
if sys.platform == "win32":
import locale
if sys.stdout.encoding != "utf-8":
os.environ["PYTHONIOENCODING"] = "utf-8"
# Re-open stdout/stderr with UTF-8 encoding
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
@@ -200,9 +212,8 @@ def onboard():
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
from nanobot.providers.custom_provider import CustomProvider
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
@@ -213,6 +224,7 @@ def _make_provider(config: Config):
return OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
from nanobot.providers.custom_provider import CustomProvider
if provider_name == "custom":
return CustomProvider(
api_key=p.api_key if p else "no-key",
@@ -220,6 +232,21 @@ def _make_provider(config: Config):
default_model=model,
)
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
return AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
@@ -244,13 +271,15 @@ def _make_provider(config: Config):
@app.command()
def gateway(
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
):
"""Start the nanobot gateway."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
from nanobot.config.loader import get_data_dir, load_config
from nanobot.config.loader import load_config
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
@@ -260,16 +289,20 @@ def gateway(
import logging
logging.basicConfig(level=logging.DEBUG)
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
config_path = Path(config) if config else None
config = load_config(config_path)
if workspace:
config.agents.defaults.workspace = workspace
config = load_config()
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation)
cron_store_path = get_data_dir() / "cron" / "jobs.json"
# Use workspace path for per-instance cron store
cron_store_path = config.workspace_path / "cron" / "jobs.json"
cron = CronService(cron_store_path)
# Create agent with cron service
@@ -511,12 +544,21 @@ def agent(
else:
cli_channel, cli_chat_id = "cli", session_id
def _exit_on_sigint(signum, frame):
def _handle_signal(signum, frame):
sig_name = signal.Signals(signum).name
_restore_terminal()
console.print("\nGoodbye!")
os._exit(0)
console.print(f"\nReceived {sig_name}, goodbye!")
sys.exit(0)
signal.signal(signal.SIGINT, _exit_on_sigint)
signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
# SIGHUP is not available on Windows
if hasattr(signal, 'SIGHUP'):
signal.signal(signal.SIGHUP, _handle_signal)
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
# SIGPIPE is not available on Windows
if hasattr(signal, 'SIGPIPE'):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run())

View File

@@ -29,7 +29,9 @@ class TelegramConfig(Base):
enabled: bool = False
token: str = "" # Bot token from @BotFather
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
reply_to_message: bool = False # If true, bot replies quote the original message
@@ -42,7 +44,9 @@ class FeishuConfig(Base):
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
react_emoji: str = (
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
)
class DingTalkConfig(Base):
@@ -62,6 +66,7 @@ class DiscordConfig(Base):
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
group_policy: Literal["mention", "open"] = "mention"
class MatrixConfig(Base):
@@ -72,9 +77,13 @@ class MatrixConfig(Base):
access_token: str = ""
user_id: str = "" # @bot:matrix.org
device_id: str = ""
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound).
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
sync_stop_grace_seconds: int = (
2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
)
max_media_bytes: int = (
20 * 1024 * 1024
) # Max attachment size accepted for Matrix media handling (inbound + outbound).
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
@@ -105,7 +114,9 @@ class EmailConfig(Base):
from_address: str = ""
# Behavior
auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent
auto_reply_enabled: bool = (
True # If false, inbound email is read but no automatic reply is sent
)
poll_interval_seconds: int = 30
mark_seen: bool = True
max_body_chars: int = 12000
@@ -183,27 +194,17 @@ class QQConfig(Base):
enabled: bool = False
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
allow_from: list[str] = Field(
default_factory=list
) # Allowed user openids (empty = public access)
class MatrixConfig(Base):
"""Matrix (Element) channel configuration."""
enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = "" # e.g. @bot:matrix.org
device_id: str = ""
e2ee_enabled: bool = True # end-to-end encryption support
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
class ChannelsConfig(Base):
"""Configuration for chat channels."""
send_progress: bool = True # stream agent's text progress to the channel
send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
@@ -222,7 +223,9 @@ class AgentDefaults(Base):
workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5"
provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
provider: str = (
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
)
max_tokens: int = 8192
temperature: float = 0.1
max_tool_iterations: int = 40
@@ -248,6 +251,7 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -260,8 +264,8 @@ class ProvidersConfig(Base):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -291,7 +295,9 @@ class WebSearchConfig(Base):
class WebToolsConfig(Base):
"""Web tools configuration."""
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
@@ -305,12 +311,13 @@ class ExecToolConfig(Base):
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
command: str = "" # Stdio: command to run (e.g. "npx")
args: list[str] = Field(default_factory=list) # Stdio: command arguments
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
url: str = "" # HTTP: streamable HTTP endpoint URL
headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
tool_timeout: int = 30 # Seconds before a tool call is cancelled
url: str = "" # HTTP/SSE: endpoint URL
headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
tool_timeout: int = 30 # seconds before a tool call is cancelled
class ToolsConfig(Base):
@@ -336,7 +343,9 @@ class Config(BaseSettings):
"""Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser()
def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
def _match_provider(
self, model: str | None = None
) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS

View File

@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]

View File

@@ -0,0 +1,210 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod
async def chat(
self,

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import uuid
from typing import Any
import json_repair
@@ -15,7 +16,12 @@ class CustomProvider(LLMProvider):
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
super().__init__(api_key, api_base)
self.default_model = default_model
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
# Keep affinity stable for this provider instance to improve backend cache locality.
self._client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
default_headers={"x-session-affinity": uuid.uuid4().hex},
)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,

View File

@@ -1,5 +1,6 @@
"""LiteLLM provider implementation for multi-provider support."""
import hashlib
import os
import secrets
import string
@@ -8,6 +9,7 @@ from typing import Any
import json_repair
import litellm
from litellm import acompletion
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
@@ -165,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {}
def map_id(value: Any) -> Any:
if not isinstance(value, str):
return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
@@ -255,20 +283,37 @@ class LiteLLMProvider(LLMProvider):
"""Parse LiteLLM response into our standard format."""
choice = response.choices[0]
message = choice.message
content = message.content
finish_reason = choice.finish_reason
# Some providers (e.g. GitHub Copilot) split content and tool_calls
# across multiple choices. Merge them so tool_calls are not lost.
raw_tool_calls = []
for ch in response.choices:
msg = ch.message
if hasattr(msg, "tool_calls") and msg.tool_calls:
raw_tool_calls.extend(msg.tool_calls)
if ch.finish_reason in ("tool_calls", "stop"):
finish_reason = ch.finish_reason
if not content and msg.content:
content = msg.content
if len(response.choices) > 1:
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
len(response.choices), len(raw_tool_calls))
tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls:
# Parse arguments from JSON string if needed
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
for tc in raw_tool_calls:
# Parse arguments from JSON string if needed
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
))
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
))
usage = {}
if hasattr(response, "usage") and response.usage:
@@ -280,11 +325,11 @@ class LiteLLMProvider(LLMProvider):
reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse(
content=message.content,
content=content,
tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop",
finish_reason=finish_reason or "stop",
usage=usage,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,

View File

@@ -52,6 +52,9 @@ class OpenAICodexProvider(LLMProvider):
"parallel_tool_calls": True,
}
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
body["tools"] = _convert_tools(tools)

View File

@@ -26,33 +26,33 @@ class ProviderSpec:
"""
# identity
name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status`
name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status`
# model prefixing
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL
# gateway behavior
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
is_oauth: bool = False # if True, uses OAuth flow instead of API key
is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False
@@ -70,7 +70,6 @@ class ProviderSpec:
# ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = (
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec(
name="custom",
@@ -81,16 +80,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
is_direct=True,
),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -102,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
# AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
env_key="OPENAI_API_KEY", # OpenAI-compatible
env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix",
litellm_prefix="openai", # → openai/{model}
litellm_prefix="openai", # → openai/{model}
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -119,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="",
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(),
),
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec(
name="siliconflow",
@@ -140,7 +145,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# VolcEngine (火山引擎): OpenAI-compatible gateway
ProviderSpec(
name="volcengine",
@@ -158,9 +162,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
name="anthropic",
@@ -179,7 +181,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec(
name="openai",
@@ -197,14 +198,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
keywords=("openai-codex",),
env_key="", # OAuth-based, no API key
env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM
litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
@@ -214,16 +214,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
is_oauth=True, # OAuth-based authentication
),
# Github Copilot: uses OAuth, not API key.
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
env_key="", # OAuth-based, no API key
env_key="", # OAuth-based, no API key
display_name="Github Copilot",
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",),
env_extras=(),
is_gateway=False,
@@ -233,17 +232,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
is_oauth=True, # OAuth-based authentication
),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek",
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -253,15 +251,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -271,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +276,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4
litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=(
("ZHIPUAI_API_KEY", "{api_key}"),
),
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
@@ -293,14 +287,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -311,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +312,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"),
env_extras=(
("MOONSHOT_API_BASE", "{api_base}"),
),
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False,
model_overrides=(
("kimi-k2.5", {"temperature": 1.0}),
),
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
),
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec(
@@ -343,7 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -354,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec(
@@ -364,20 +349,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local",
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="", # user must provide in config
default_api_base="", # user must provide in config
strip_model_prefix=False,
model_overrides=(),
),
# === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec(
@@ -385,8 +368,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -403,6 +386,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers
# ---------------------------------------------------------------------------
def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -418,7 +402,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec
for spec in std_specs:
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
if any(
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
):
return spec
return None

View File

@@ -5,6 +5,19 @@ from datetime import datetime
from pathlib import Path
def detect_image_mime(data: bytes) -> str | None:
"""Detect image MIME type from magic bytes, ignoring file extension."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:3] == b"\xff\xd8\xff":
return "image/jpeg"
if data[:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
return None
def ensure_dir(path: Path) -> Path:
"""Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True)
@@ -34,6 +47,38 @@ def safe_filename(name: str) -> str:
return _UNSAFE_CHARS.sub("_", name).strip()
def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
Split content into chunks within max_len, preferring line breaks.
Args:
content: The text content to split.
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
Returns:
List of message chunks, each within max_len.
"""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
# Try to break at newline first, then space, then hard break
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files