Merge branch 'main' of https://github.com/kunalk16/nanobot into feat-support-azure-openai

This commit is contained in:
Kunal Karmakar
2026-03-06 10:39:29 +00:00
12 changed files with 386 additions and 159 deletions

View File

@@ -664,6 +664,7 @@ Config file: `~/.nanobot/config.json`
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config. > - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key | | Provider | Purpose | Get API Key |
|----------|---------|-------------| |----------|---------|-------------|

View File

@@ -10,6 +10,7 @@ from typing import Any
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import detect_image_mime
class ContextBuilder: class ContextBuilder:
@@ -136,10 +137,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
images = [] images = []
for path in media: for path in media:
p = Path(path) p = Path(path)
mime, _ = mimetypes.guess_type(path) if not p.is_file():
if not p.is_file() or not mime or not mime.startswith("image/"):
continue continue
b64 = base64.b64encode(p.read_bytes()).decode() raw = p.read_bytes()
# Detect real MIME type from magic bytes; fallback to filename guess
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
if not mime or not mime.startswith("image/"):
continue
b64 = base64.b64encode(raw).decode()
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
if not images: if not images:

View File

@@ -128,6 +128,13 @@ class MemoryStore:
# Some providers return arguments as a JSON string instead of dict # Some providers return arguments as a JSON string instead of dict
if isinstance(args, str): if isinstance(args, str):
args = json.loads(args) args = json.loads(args)
# Some providers return arguments as a list (handle edge case)
if isinstance(args, list):
if args and isinstance(args[0], dict):
args = args[0]
else:
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
return False
if not isinstance(args, dict): if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False return False

View File

@@ -13,34 +13,13 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import DiscordConfig from nanobot.config.schema import DiscordConfig
from nanobot.utils.helpers import split_message
DISCORD_API_BASE = "https://discord.com/api/v10" DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit MAX_MESSAGE_LEN = 2000 # Discord message character limit
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class DiscordChannel(BaseChannel): class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket.""" """Discord channel using Gateway websocket."""
@@ -105,7 +84,7 @@ class DiscordChannel(BaseChannel):
headers = {"Authorization": f"Bot {self.config.token}"} headers = {"Authorization": f"Bot {self.config.token}"}
try: try:
chunks = _split_message(msg.content or "") chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks: if not chunks:
return return

View File

@@ -472,8 +472,124 @@ class FeishuChannel(BaseChannel):
return elements or [{"tag": "markdown", "content": content}] return elements or [{"tag": "markdown", "content": content}]
# ── Smart format detection ──────────────────────────────────────────
# Patterns that indicate "complex" markdown needing card rendering
_COMPLEX_MD_RE = re.compile(
r"```" # fenced code block
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
r"|^#{1,6}\s+" # headings
, re.MULTILINE,
)
# Simple markdown patterns (bold, italic, strikethrough)
_SIMPLE_MD_RE = re.compile(
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
r"|~~.+?~~" # ~~strikethrough~~
, re.DOTALL,
)
# Markdown link: [text](url)
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
# Unordered list items
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
# Ordered list items
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
# Max length for plain text format
_TEXT_MAX_LEN = 200
# Max length for post (rich text) format; beyond this, use card
_POST_MAX_LEN = 2000
@classmethod
def _detect_msg_format(cls, content: str) -> str:
"""Determine the optimal Feishu message format for *content*.
Returns one of:
- ``"text"`` plain text, short and no markdown
- ``"post"`` rich text (links only, moderate length)
- ``"interactive"`` card with full markdown rendering
"""
stripped = content.strip()
# Complex markdown (code blocks, tables, headings) → always card
if cls._COMPLEX_MD_RE.search(stripped):
return "interactive"
# Long content → card (better readability with card layout)
if len(stripped) > cls._POST_MAX_LEN:
return "interactive"
# Has bold/italic/strikethrough → card (post format can't render these)
if cls._SIMPLE_MD_RE.search(stripped):
return "interactive"
# Has list items → card (post format can't render list bullets well)
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
return "interactive"
# Has links → post format (supports <a> tags)
if cls._MD_LINK_RE.search(stripped):
return "post"
# Short plain text → text format
if len(stripped) <= cls._TEXT_MAX_LEN:
return "text"
# Medium plain text without any formatting → post format
return "post"
@classmethod
def _markdown_to_post(cls, content: str) -> str:
"""Convert markdown content to Feishu post message JSON.
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
Each line becomes a paragraph (row) in the post body.
"""
lines = content.strip().split("\n")
paragraphs: list[list[dict]] = []
for line in lines:
elements: list[dict] = []
last_end = 0
for m in cls._MD_LINK_RE.finditer(line):
# Text before this link
before = line[last_end:m.start()]
if before:
elements.append({"tag": "text", "text": before})
elements.append({
"tag": "a",
"text": m.group(1),
"href": m.group(2),
})
last_end = m.end()
# Remaining text after last link
remaining = line[last_end:]
if remaining:
elements.append({"tag": "text", "text": remaining})
# Empty line → empty paragraph for spacing
if not elements:
elements.append({"tag": "text", "text": ""})
paragraphs.append(elements)
post_body = {
"zh_cn": {
"content": paragraphs,
}
}
return json.dumps(post_body, ensure_ascii=False)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"} _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
_AUDIO_EXTS = {".opus"} _AUDIO_EXTS = {".opus"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = { _FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc", ".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt", ".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
@@ -682,21 +798,46 @@ class FeishuChannel(BaseChannel):
else: else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path) key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
if key: if key:
media_type = "audio" if ext in self._AUDIO_EXTS else "file" # Use msg_type "media" for audio/video so users can play inline;
# "file" for everything else (documents, archives, etc.)
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
media_type = "media"
else:
media_type = "file"
await loop.run_in_executor( await loop.run_in_executor(
None, self._send_message_sync, None, self._send_message_sync,
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False), receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
) )
if msg.content and msg.content.strip(): if msg.content and msg.content.strip():
elements = self._build_card_elements(msg.content) fmt = self._detect_msg_format(msg.content)
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk} if fmt == "text":
# Short plain text send as simple text message
text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
await loop.run_in_executor( await loop.run_in_executor(
None, self._send_message_sync, None, self._send_message_sync,
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False), receive_id_type, msg.chat_id, "text", text_body,
) )
elif fmt == "post":
# Medium content with links send as rich-text post
post_body = self._markdown_to_post(msg.content)
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "post", post_body,
)
else:
# Complex / long content send as interactive card
elements = self._build_card_elements(msg.content)
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
)
except Exception as e: except Exception as e:
logger.error("Error sending Feishu message: {}", e) logger.error("Error sending Feishu message: {}", e)

View File

@@ -14,6 +14,9 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import TelegramConfig from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
def _markdown_to_telegram_html(text: str) -> str: def _markdown_to_telegram_html(text: str) -> str:
@@ -79,26 +82,6 @@ def _markdown_to_telegram_html(text: str) -> str:
return text return text
def _split_message(content: str, max_len: int = 4000) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos == -1:
pos = cut.rfind(' ')
if pos == -1:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
""" """
Telegram channel using long polling. Telegram channel using long polling.
@@ -273,8 +256,8 @@ class TelegramChannel(BaseChannel):
if msg.content and msg.content != "[empty message]": if msg.content and msg.content != "[empty message]":
is_progress = msg.metadata.get("_progress", False) is_progress = msg.metadata.get("_progress", False)
draft_id = msg.metadata.get("message_id") draft_id = msg.metadata.get("message_id")
for chunk in _split_message(msg.content): for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
try: try:
html = _markdown_to_telegram_html(chunk) html = _markdown_to_telegram_html(chunk)
if is_progress and draft_id: if is_progress and draft_id:

View File

@@ -7,6 +7,18 @@ import signal
import sys import sys
from pathlib import Path from pathlib import Path
# Force UTF-8 encoding for Windows console
if sys.platform == "win32":
import locale
if sys.stdout.encoding != "utf-8":
os.environ["PYTHONIOENCODING"] = "utf-8"
# Re-open stdout/stderr with UTF-8 encoding
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
import typer import typer
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML from prompt_toolkit.formatted_text import HTML
@@ -200,8 +212,6 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.custom_provider import CustomProvider
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
@@ -214,6 +224,7 @@ def _make_provider(config: Config):
return OpenAICodexProvider(default_model=model) return OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
from nanobot.providers.custom_provider import CustomProvider
if provider_name == "custom": if provider_name == "custom":
return CustomProvider( return CustomProvider(
api_key=p.api_key if p else "no-key", api_key=p.api_key if p else "no-key",
@@ -235,6 +246,7 @@ def _make_provider(config: Config):
default_model=model, default_model=model,
) )
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name) spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth): if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
@@ -540,9 +552,13 @@ def agent(
signal.signal(signal.SIGINT, _handle_signal) signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal) signal.signal(signal.SIGTERM, _handle_signal)
signal.signal(signal.SIGHUP, _handle_signal) # SIGHUP is not available on Windows
if hasattr(signal, 'SIGHUP'):
signal.signal(signal.SIGHUP, _handle_signal)
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
signal.signal(signal.SIGPIPE, signal.SIG_IGN) # SIGPIPE is not available on Windows
if hasattr(signal, 'SIGPIPE'):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive(): async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run()) bus_task = asyncio.create_task(agent_loop.run())

View File

@@ -199,21 +199,6 @@ class QQConfig(Base):
) # Allowed user openids (empty = public access) ) # Allowed user openids (empty = public access)
class MatrixConfig(Base):
"""Matrix (Element) channel configuration."""
enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = "" # e.g. @bot:matrix.org
device_id: str = ""
e2ee_enabled: bool = True # end-to-end encryption support
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
class ChannelsConfig(Base): class ChannelsConfig(Base):
@@ -279,12 +264,8 @@ class ProvidersConfig(Base):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field( siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
default_factory=ProviderConfig volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
) # SiliconFlow (硅基流动) API gateway
volcengine: ProviderConfig = Field(
default_factory=ProviderConfig
) # VolcEngine (火山引擎) API gateway
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)

View File

@@ -8,6 +8,7 @@ from typing import Any
import json_repair import json_repair
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway from nanobot.providers.registry import find_by_model, find_gateway
@@ -255,20 +256,37 @@ class LiteLLMProvider(LLMProvider):
"""Parse LiteLLM response into our standard format.""" """Parse LiteLLM response into our standard format."""
choice = response.choices[0] choice = response.choices[0]
message = choice.message message = choice.message
content = message.content
finish_reason = choice.finish_reason
# Some providers (e.g. GitHub Copilot) split content and tool_calls
# across multiple choices. Merge them so tool_calls are not lost.
raw_tool_calls = []
for ch in response.choices:
msg = ch.message
if hasattr(msg, "tool_calls") and msg.tool_calls:
raw_tool_calls.extend(msg.tool_calls)
if ch.finish_reason in ("tool_calls", "stop"):
finish_reason = ch.finish_reason
if not content and msg.content:
content = msg.content
if len(response.choices) > 1:
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
len(response.choices), len(raw_tool_calls))
tool_calls = [] tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls: for tc in raw_tool_calls:
for tc in message.tool_calls: # Parse arguments from JSON string if needed
# Parse arguments from JSON string if needed args = tc.function.arguments
args = tc.function.arguments if isinstance(args, str):
if isinstance(args, str): args = json_repair.loads(args)
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=_short_tool_id(), id=_short_tool_id(),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
)) ))
usage = {} usage = {}
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
@@ -280,11 +298,11 @@ class LiteLLMProvider(LLMProvider):
reasoning_content = getattr(message, "reasoning_content", None) or None reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=content,
tool_calls=tool_calls, tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop", finish_reason=finish_reason or "stop",
usage=usage, usage=usage,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks, thinking_blocks=thinking_blocks,

View File

@@ -26,33 +26,33 @@ class ProviderSpec:
""" """
# identity # identity
name: str # config field name, e.g. "dashscope" name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase) keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY" env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status` display_name: str = "" # shown in `nanobot status`
# model prefixing # model prefixing
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}" litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = () env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection # gateway / local detection
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix) is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama) is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-" detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL default_api_base: str = "" # fallback base URL
# gateway behavior # gateway behavior
strip_model_prefix: bool = False # strip "provider/" before re-prefixing strip_model_prefix: bool = False # strip "provider/" before re-prefixing
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys # OAuth-based providers (e.g., OpenAI Codex) don't use API keys
is_oauth: bool = False # if True, uses OAuth flow instead of API key is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider) # Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False is_direct: bool = False
@@ -70,7 +70,6 @@ class ProviderSpec:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = ( PROVIDERS: tuple[ProviderSpec, ...] = (
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ====== # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec( ProviderSpec(
name="custom", name="custom",
@@ -90,17 +89,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
litellm_prefix="", litellm_prefix="",
is_direct=True, is_direct=True,
), ),
# === Gateways (detected by api_key / api_base, not model name) ========= # === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback. # Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-" # OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec( ProviderSpec(
name="openrouter", name="openrouter",
keywords=("openrouter",), keywords=("openrouter",),
env_key="OPENROUTER_API_KEY", env_key="OPENROUTER_API_KEY",
display_name="OpenRouter", display_name="OpenRouter",
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3 litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=True, is_gateway=True,
@@ -112,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(), model_overrides=(),
supports_prompt_caching=True, supports_prompt_caching=True,
), ),
# AiHubMix: global gateway, OpenAI-compatible interface. # AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3", # strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3". # so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec( ProviderSpec(
name="aihubmix", name="aihubmix",
keywords=("aihubmix",), keywords=("aihubmix",),
env_key="OPENAI_API_KEY", # OpenAI-compatible env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix", display_name="AiHubMix",
litellm_prefix="openai", # → openai/{model} litellm_prefix="openai", # → openai/{model}
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=True, is_gateway=True,
@@ -129,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="aihubmix", detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1", default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(), model_overrides=(),
), ),
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix # SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec( ProviderSpec(
name="siliconflow", name="siliconflow",
@@ -150,7 +145,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine (火山引擎): OpenAI-compatible gateway # VolcEngine (火山引擎): OpenAI-compatible gateway
ProviderSpec( ProviderSpec(
name="volcengine", name="volcengine",
@@ -168,9 +162,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Standard providers (matched by model-name keywords) =============== # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="anthropic", name="anthropic",
@@ -189,7 +181,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(), model_overrides=(),
supports_prompt_caching=True, supports_prompt_caching=True,
), ),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="openai", name="openai",
@@ -207,14 +198,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# OpenAI Codex: uses OAuth, not API key. # OpenAI Codex: uses OAuth, not API key.
ProviderSpec( ProviderSpec(
name="openai_codex", name="openai_codex",
keywords=("openai-codex",), keywords=("openai-codex",),
env_key="", # OAuth-based, no API key env_key="", # OAuth-based, no API key
display_name="OpenAI Codex", display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -224,16 +214,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api", default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
is_oauth=True, # OAuth-based authentication is_oauth=True, # OAuth-based authentication
), ),
# Github Copilot: uses OAuth, not API key. # Github Copilot: uses OAuth, not API key.
ProviderSpec( ProviderSpec(
name="github_copilot", name="github_copilot",
keywords=("github_copilot", "copilot"), keywords=("github_copilot", "copilot"),
env_key="", # OAuth-based, no API key env_key="", # OAuth-based, no API key
display_name="Github Copilot", display_name="Github Copilot",
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",), skip_prefixes=("github_copilot/",),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -243,17 +232,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="", default_api_base="",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
is_oauth=True, # OAuth-based authentication is_oauth=True, # OAuth-based authentication
), ),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing. # DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec( ProviderSpec(
name="deepseek", name="deepseek",
keywords=("deepseek",), keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY", env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek", display_name="DeepSeek",
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -263,15 +251,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Gemini: needs "gemini/" prefix for LiteLLM. # Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec( ProviderSpec(
name="gemini", name="gemini",
keywords=("gemini",), keywords=("gemini",),
env_key="GEMINI_API_KEY", env_key="GEMINI_API_KEY",
display_name="Gemini", display_name="Gemini",
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -281,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Zhipu: LiteLLM uses "zai/" prefix. # Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway. # skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -290,11 +276,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"), keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY", env_key="ZAI_API_KEY",
display_name="Zhipu AI", display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4 litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=( env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
("ZHIPUAI_API_KEY", "{api_key}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
@@ -303,14 +287,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# DashScope: Qwen models, needs "dashscope/" prefix. # DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec( ProviderSpec(
name="dashscope", name="dashscope",
keywords=("qwen", "dashscope"), keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY", env_key="DASHSCOPE_API_KEY",
display_name="DashScope", display_name="DashScope",
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"), skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -321,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Moonshot: Kimi models, needs "moonshot/" prefix. # Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0. # Kimi K2.5 API enforces temperature >= 1.0.
@@ -330,22 +312,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"), keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY", env_key="MOONSHOT_API_KEY",
display_name="Moonshot", display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"), skip_prefixes=("moonshot/", "openrouter/"),
env_extras=( env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
("MOONSHOT_API_BASE", "{api_base}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="", detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=( model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
("kimi-k2.5", {"temperature": 1.0}),
),
), ),
# MiniMax: needs "minimax/" prefix for LiteLLM routing. # MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1. # Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec( ProviderSpec(
@@ -353,7 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",), keywords=("minimax",),
env_key="MINIMAX_API_KEY", env_key="MINIMAX_API_KEY",
display_name="MiniMax", display_name="MiniMax",
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1 litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"), skip_prefixes=("minimax/", "openrouter/"),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -364,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Local deployment (matched by config key, NOT by api_base) ========= # === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server. # vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm"). # Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec( ProviderSpec(
@@ -374,20 +349,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",), keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY", env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local", display_name="vLLM/Local",
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=True, is_local=True,
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="", detect_by_base_keyword="",
default_api_base="", # user must provide in config default_api_base="", # user must provide in config
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Auxiliary (not a primary LLM provider) ============================ # === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM. # Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec( ProviderSpec(
@@ -395,8 +368,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",), keywords=("groq",),
env_key="GROQ_API_KEY", env_key="GROQ_API_KEY",
display_name="Groq", display_name="Groq",
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192 litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -413,6 +386,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers # Lookup helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def find_by_model(model: str) -> ProviderSpec | None: def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive). """Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead.""" Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -428,7 +402,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec return spec
for spec in std_specs: for spec in std_specs:
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords): if any(
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
):
return spec return spec
return None return None

View File

@@ -5,6 +5,19 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
def detect_image_mime(data: bytes) -> str | None:
"""Detect image MIME type from magic bytes, ignoring file extension."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:3] == b"\xff\xd8\xff":
return "image/jpeg"
if data[:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
return None
def ensure_dir(path: Path) -> Path: def ensure_dir(path: Path) -> Path:
"""Ensure directory exists, return it.""" """Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
@@ -34,6 +47,38 @@ def safe_filename(name: str) -> str:
return _UNSAFE_CHARS.sub("_", name).strip() return _UNSAFE_CHARS.sub("_", name).strip()
def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
Split content into chunks within max_len, preferring line breaks.
Args:
content: The text content to split.
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
Returns:
List of message chunks, each within max_len.
"""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
# Try to break at newline first, then space, then hard break
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files.""" """Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files from importlib.resources import files as pkg_files

View File

@@ -145,3 +145,78 @@ class TestMemoryConsolidationTypeHandling:
assert result is True assert result is True
provider.chat.assert_not_called() provider.chat.assert_not_called()
@pytest.mark.asyncio
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
    """Some providers wrap tool-call arguments in a list; the first dict element is used."""
    memory = MemoryStore(tmp_path)
    llm = AsyncMock()
    # Arguments arrive as a one-element list wrapping the payload dict.
    wrapped_arguments = [
        {
            "history_entry": "[2026-01-01] User discussed testing.",
            "memory_update": "# Memory\nUser likes testing.",
        }
    ]
    llm.chat = AsyncMock(
        return_value=LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(id="call_1", name="save_memory", arguments=wrapped_arguments)
            ],
        )
    )
    session = _make_session(message_count=60)

    ok = await memory.consolidate(session, llm, "test-model", memory_window=50)

    assert ok is True
    assert "User discussed testing." in memory.history_file.read_text()
    assert "User likes testing." in memory.memory_file.read_text()
@pytest.mark.asyncio
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
    """An empty argument list carries no payload, so consolidation must fail."""
    memory = MemoryStore(tmp_path)
    llm = AsyncMock()
    llm.chat = AsyncMock(
        return_value=LLMResponse(
            content=None,
            tool_calls=[ToolCallRequest(id="call_1", name="save_memory", arguments=[])],
        )
    )
    session = _make_session(message_count=60)

    ok = await memory.consolidate(session, llm, "test-model", memory_window=50)

    assert ok is False
@pytest.mark.asyncio
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
    """A list whose elements are not dicts yields no usable payload, so consolidation fails."""
    memory = MemoryStore(tmp_path)
    llm = AsyncMock()
    llm.chat = AsyncMock(
        return_value=LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(id="call_1", name="save_memory", arguments=["string", "content"])
            ],
        )
    )
    session = _make_session(message_count=60)

    ok = await memory.consolidate(session, llm, "test-model", memory_window=50)

    assert ok is False