diff --git a/.gitignore b/.gitignore
index d7b930d..374875a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.worktrees/
 .assets
 .env
 *.pyc
@@ -19,4 +20,4 @@ __pycache__/
 poetry.lock
 .pytest_cache/
 botpy.log
-tests/
+
diff --git a/README.md b/README.md
index 33cdeee..03f042a 100644
--- a/README.md
+++ b/README.md
@@ -12,11 +12,11 @@
-🤖 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw)
+🤖 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).

-⚡️ Delivers core agent functionality in just **~4,000** lines of code – **99% smaller** than Clawdbot's 430k+ lines.
+⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.

-📊 Real-time line count: **3,935 lines** (run `bash core_agent_lines.sh` to verify anytime)
+📊 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.

 ## 📢 News

@@ -293,12 +293,18 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.json`:
       "discord": {
         "enabled": true,
         "token": "YOUR_BOT_TOKEN",
-        "allowFrom": ["YOUR_USER_ID"]
+        "allowFrom": ["YOUR_USER_ID"],
+        "groupPolicy": "mention"
       }
     }
   }
 ```

+> `groupPolicy` controls how the bot responds in group channels:
+> - `"mention"` (default) – Only respond when @mentioned
+> - `"open"` – Respond to all messages
+> DMs always respond when the sender is in `allowFrom`.
+
 **5. Invite the bot**
 - OAuth2 → URL Generator
 - Scopes: `bot`

@@ -414,6 +420,10 @@ nanobot channels login
 nanobot gateway
 ```

+> WhatsApp bridge updates are not applied automatically for existing installations.
+> If you upgrade nanobot and need the latest WhatsApp bridge, run:
+> `rm -rf ~/.nanobot/bridge && nanobot channels login`
+
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
+def _render_table_box(table_lines: list[str]) -> str:
+    """Render a markdown table as box-drawing text for monospace
+    display."""
+
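+ # East-Asian wide/fullwidth characters occupy two terminal columns,
+ # so cell widths are measured in display columns rather than len().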
+ def dw(s: str) -> int:
+ return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
+
+ rows: list[list[str]] = []
+ has_sep = False
+ for line in table_lines:
+ cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
+ if all(re.match(r'^:?-+:?$', c) for c in cells if c):
+ has_sep = True
+ continue
+ rows.append(cells)
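+ # Require a header separator row (e.g. |---|---|); anything else is
+ # not a well-formed markdown table and is returned untouched.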
+ if not rows or not has_sep:
+ return '\n'.join(table_lines)
+
+ ncols = max(len(r) for r in rows)
+ for r in rows:
+ r.extend([''] * (ncols - len(r)))
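+ # Each column is as wide as its widest cell, in display columns.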
+ widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
+
+ def dr(cells: list[str]) -> str:
+ return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
+
+ out = [dr(rows[0])]
+ out.append(' '.join('─' * w for w in widths))
+ for row in rows[1:]:
+ out.append(dr(row))
+ return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str:
@@ -31,6 +77,27 @@ def _markdown_to_telegram_html(text: str) -> str:
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
+ # 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
+ lines = text.split('\n')
+ rebuilt: list[str] = []
+ li = 0
+ while li < len(lines):
+ if re.match(r'^\s*\|.+\|', lines[li]):
+ tbl: list[str] = []
+ while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
+ tbl.append(lines[li])
+ li += 1
+ box = _render_table_box(tbl)
+ if box != '\n'.join(tbl):
+ code_blocks.append(box)
+ rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
+ else:
+ rebuilt.extend(tbl)
+ else:
+ rebuilt.append(lines[li])
+ li += 1
+ text = '\n'.join(rebuilt)
+
# 2. Extract and protect inline code
inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str:
@@ -79,26 +146,6 @@ def _markdown_to_telegram_html(text: str) -> str:
return text
-def _split_message(content: str, max_len: int = 4000) -> list[str]:
- """Split content into chunks within max_len, preferring line breaks."""
- if len(content) <= max_len:
- return [content]
- chunks: list[str] = []
- while content:
- if len(content) <= max_len:
- chunks.append(content)
- break
- cut = content[:max_len]
- pos = cut.rfind('\n')
- if pos == -1:
- pos = cut.rfind(' ')
- if pos == -1:
- pos = max_len
- chunks.append(content[:pos])
- content = content[pos:].lstrip()
- return chunks
-
-
class TelegramChannel(BaseChannel):
"""
Telegram channel using long polling.
@@ -154,6 +201,7 @@ class TelegramChannel(BaseChannel):
# Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
+ self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents
@@ -229,7 +277,9 @@ class TelegramChannel(BaseChannel):
logger.warning("Telegram bot not running")
return
- self._stop_typing(msg.chat_id)
+ # Only stop typing indicator for final responses
+ if not msg.metadata.get("_progress", False):
+ self._stop_typing(msg.chat_id)
try:
chat_id = int(msg.chat_id)
@@ -273,25 +323,49 @@ class TelegramChannel(BaseChannel):
# Send text content
if msg.content and msg.content != "[empty message]":
- for chunk in _split_message(msg.content):
- try:
- html = _markdown_to_telegram_html(chunk)
- await self._app.bot.send_message(
- chat_id=chat_id,
- text=html,
- parse_mode="HTML",
- reply_parameters=reply_params
- )
- except Exception as e:
- logger.warning("HTML parse failed, falling back to plain text: {}", e)
- try:
- await self._app.bot.send_message(
- chat_id=chat_id,
- text=chunk,
- reply_parameters=reply_params
- )
- except Exception as e2:
- logger.error("Error sending Telegram message: {}", e2)
+ is_progress = msg.metadata.get("_progress", False)
+
+ for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
+ # Final response: simulate streaming via draft, then persist
+ if not is_progress:
+ await self._send_with_streaming(chat_id, chunk, reply_params)
+ else:
+ await self._send_text(chat_id, chunk, reply_params)
+
+ async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None:
+ """Send a plain text message with HTML fallback."""
+ try:
+ html = _markdown_to_telegram_html(text)
+ await self._app.bot.send_message(
+ chat_id=chat_id, text=html, parse_mode="HTML",
+ reply_parameters=reply_params,
+ )
+ except Exception as e:
+ logger.warning("HTML parse failed, falling back to plain text: {}", e)
+ try:
+ await self._app.bot.send_message(
+ chat_id=chat_id, text=text, reply_parameters=reply_params,
+ )
+ except Exception as e2:
+ logger.error("Error sending Telegram message: {}", e2)
+
+ async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None:
+ """Simulate streaming via send_message_draft, then persist with send_message."""
+ draft_id = int(time.time() * 1000) % (2**31)
+ try:
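+ # Reveal the text in roughly 8 increments (at least 40 chars each)
+ # so the draft grows like a streamed response.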
+ step = max(len(text) // 8, 40)
+ for i in range(step, len(text), step):
+ await self._app.bot.send_message_draft(
+ chat_id=chat_id, draft_id=draft_id, text=text[:i],
+ )
+ await asyncio.sleep(0.04)
+ await self._app.bot.send_message_draft(
+ chat_id=chat_id, draft_id=draft_id, text=text,
+ )
+ await asyncio.sleep(0.15)
+ except Exception:
+ pass
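+ # Drafts are best-effort; always persist the final text as a real message.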
+ await self._send_text(chat_id, text, reply_params)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command."""
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index 0d1ec7e..1307716 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -2,6 +2,7 @@
import asyncio
import json
+import mimetypes
from collections import OrderedDict
from loguru import logger
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
+ # Extract media paths (images/documents/videos downloaded by the bridge)
+ media_paths = data.get("media") or []
+
+ # Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
+ if media_paths:
+ for p in media_paths:
+ mime, _ = mimetypes.guess_type(p)
+ media_type = "image" if mime and mime.startswith("image/") else "file"
+ media_tag = f"[{media_type}: {p}]"
+ content = f"{content}\n{media_tag}" if content else media_tag
+
await self._handle_message(
sender_id=sender_id,
chat_id=sender, # Use full LID for replies
content=content,
+ media=media_paths,
metadata={
"message_id": message_id,
"timestamp": data.get("timestamp"),
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index b75a2bc..ca5d8d7 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -7,6 +7,18 @@ import signal
import sys
from pathlib import Path
+# Force UTF-8 encoding for Windows console
+if sys.platform == "win32":
+ import locale
+ if sys.stdout.encoding != "utf-8":
+ os.environ["PYTHONIOENCODING"] = "utf-8"
+ # Re-open stdout/stderr with UTF-8 encoding
+ try:
+ sys.stdout.reconfigure(encoding="utf-8", errors="replace")
+ sys.stderr.reconfigure(encoding="utf-8", errors="replace")
+ except Exception:
+ pass
+
import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
@@ -200,9 +212,8 @@ def onboard():
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
- from nanobot.providers.custom_provider import CustomProvider
- from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
@@ -213,6 +224,7 @@ def _make_provider(config: Config):
return OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
+ from nanobot.providers.custom_provider import CustomProvider
if provider_name == "custom":
return CustomProvider(
api_key=p.api_key if p else "no-key",
@@ -220,6 +232,21 @@ def _make_provider(config: Config):
default_model=model,
)
+ # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
+ if provider_name == "azure_openai":
+ if not p or not p.api_key or not p.api_base:
+ console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
+ console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
+ console.print("Use the model field to specify the deployment name.")
+ raise typer.Exit(1)
+
+ return AzureOpenAIProvider(
+ api_key=p.api_key,
+ api_base=p.api_base,
+ default_model=model,
+ )
+
+ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
@@ -244,13 +271,15 @@ def _make_provider(config: Config):
@app.command()
def gateway(
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
):
"""Start the nanobot gateway."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
- from nanobot.config.loader import get_data_dir, load_config
+ from nanobot.config.loader import load_config
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
@@ -260,16 +289,20 @@ def gateway(
import logging
logging.basicConfig(level=logging.DEBUG)
- console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
+ config_path = Path(config) if config else None
+ config = load_config(config_path)
+ if workspace:
+ config.agents.defaults.workspace = workspace
- config = load_config()
+ console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation)
- cron_store_path = get_data_dir() / "cron" / "jobs.json"
+ # Use workspace path for per-instance cron store
+ cron_store_path = config.workspace_path / "cron" / "jobs.json"
cron = CronService(cron_store_path)
# Create agent with cron service
@@ -511,12 +544,21 @@ def agent(
else:
cli_channel, cli_chat_id = "cli", session_id
- def _exit_on_sigint(signum, frame):
+ def _handle_signal(signum, frame):
+ sig_name = signal.Signals(signum).name
_restore_terminal()
- console.print("\nGoodbye!")
- os._exit(0)
+ console.print(f"\nReceived {sig_name}, goodbye!")
+ sys.exit(0)
- signal.signal(signal.SIGINT, _exit_on_sigint)
+ signal.signal(signal.SIGINT, _handle_signal)
+ signal.signal(signal.SIGTERM, _handle_signal)
+ # SIGHUP is not available on Windows
+ if hasattr(signal, 'SIGHUP'):
+ signal.signal(signal.SIGHUP, _handle_signal)
+ # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
+ # SIGPIPE is not available on Windows
+ if hasattr(signal, 'SIGPIPE'):
+ signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run())
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index 61a7bd2..803cb61 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -29,7 +29,9 @@ class TelegramConfig(Base):
enabled: bool = False
token: str = "" # Bot token from @BotFather
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
- proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ proxy: str | None = (
+ None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ )
reply_to_message: bool = False # If true, bot replies quote the original message
@@ -42,7 +44,9 @@ class FeishuConfig(Base):
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
- react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
+ react_emoji: str = (
+ "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
+ )
class DingTalkConfig(Base):
@@ -62,6 +66,7 @@ class DiscordConfig(Base):
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
+ group_policy: Literal["mention", "open"] = "mention"
class MatrixConfig(Base):
@@ -72,9 +77,13 @@ class MatrixConfig(Base):
access_token: str = ""
user_id: str = "" # @bot:matrix.org
device_id: str = ""
- e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
- sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
- max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound).
+ e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
+ sync_stop_grace_seconds: int = (
+ 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
+ )
+ max_media_bytes: int = (
+ 20 * 1024 * 1024
+ ) # Max attachment size accepted for Matrix media handling (inbound + outbound).
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
@@ -105,7 +114,9 @@ class EmailConfig(Base):
from_address: str = ""
# Behavior
- auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent
+ auto_reply_enabled: bool = (
+ True # If false, inbound email is read but no automatic reply is sent
+ )
poll_interval_seconds: int = 30
mark_seen: bool = True
max_body_chars: int = 12000
@@ -183,27 +194,17 @@ class QQConfig(Base):
enabled: bool = False
app_id: str = "" # Bot ID (AppID) from q.qq.com
secret: str = "" # Bot secret (AppSecret) from q.qq.com
- allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
+ allow_from: list[str] = Field(
+ default_factory=list
+ ) # Allowed user openids (empty = public access)
+
+
-class MatrixConfig(Base):
- """Matrix (Element) channel configuration."""
- enabled: bool = False
- homeserver: str = "https://matrix.org"
- access_token: str = ""
- user_id: str = "" # e.g. @bot:matrix.org
- device_id: str = ""
- e2ee_enabled: bool = True # end-to-end encryption support
- sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
- max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
- allow_from: list[str] = Field(default_factory=list)
- group_policy: Literal["open", "mention", "allowlist"] = "open"
- group_allow_from: list[str] = Field(default_factory=list)
- allow_room_mentions: bool = False
class ChannelsConfig(Base):
"""Configuration for chat channels."""
- send_progress: bool = True # stream agent's text progress to the channel
+ send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
@@ -222,7 +223,9 @@ class AgentDefaults(Base):
workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5"
- provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+ provider: str = (
+ "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+ )
max_tokens: int = 8192
temperature: float = 0.1
max_tool_iterations: int = 40
@@ -248,6 +251,7 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
+ azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -260,8 +264,8 @@ class ProvidersConfig(Base):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
- siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
- volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway
+ siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
+ volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -291,7 +295,9 @@ class WebSearchConfig(Base):
class WebToolsConfig(Base):
"""Web tools configuration."""
- proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ proxy: str | None = (
+ None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ )
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
@@ -305,12 +311,13 @@ class ExecToolConfig(Base):
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
+ type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
command: str = "" # Stdio: command to run (e.g. "npx")
args: list[str] = Field(default_factory=list) # Stdio: command arguments
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
- url: str = "" # HTTP: streamable HTTP endpoint URL
- headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
- tool_timeout: int = 30 # Seconds before a tool call is cancelled
+ url: str = "" # HTTP/SSE: endpoint URL
+ headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
+ tool_timeout: int = 30 # seconds before a tool call is cancelled
class ToolsConfig(Base):
@@ -336,7 +343,9 @@ class Config(BaseSettings):
"""Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser()
- def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
+ def _match_provider(
+ self, model: str | None = None
+ ) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS
diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py
index b2bb2b9..5bd06f9 100644
--- a/nanobot/providers/__init__.py
+++ b/nanobot/providers/__init__.py
@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
-__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
+__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py
new file mode 100644
index 0000000..bd79b00
--- /dev/null
+++ b/nanobot/providers/azure_openai_provider.py
@@ -0,0 +1,210 @@
+"""Azure OpenAI provider implementation with API version 2024-10-21."""
+
+from __future__ import annotations
+
+import uuid
+from typing import Any
+from urllib.parse import urljoin
+
+import httpx
+import json_repair
+
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
+
+_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+
+
+class AzureOpenAIProvider(LLMProvider):
+ """
+ Azure OpenAI provider with API version 2024-10-21 compliance.
+
+ Features:
+ - Hardcoded API version 2024-10-21
+ - Uses model field as Azure deployment name in URL path
+ - Uses api-key header instead of Authorization Bearer
+ - Uses max_completion_tokens instead of max_tokens
+ - Direct HTTP calls, bypasses LiteLLM
+ """
+
+ def __init__(
+ self,
+ api_key: str = "",
+ api_base: str = "",
+ default_model: str = "gpt-5.2-chat",
+ ):
+ super().__init__(api_key, api_base)
+ self.default_model = default_model
+ self.api_version = "2024-10-21"
+
+ # Validate required parameters
+ if not api_key:
+ raise ValueError("Azure OpenAI api_key is required")
+ if not api_base:
+ raise ValueError("Azure OpenAI api_base is required")
+
+ # Ensure api_base ends with /
+ if not api_base.endswith('/'):
+ api_base += '/'
+ self.api_base = api_base
+
+ def _build_chat_url(self, deployment_name: str) -> str:
+ """Build the Azure OpenAI chat completions URL."""
+ # Azure OpenAI URL format:
+ # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
+ base_url = self.api_base
+ if not base_url.endswith('/'):
+ base_url += '/'
+
+ url = urljoin(
+ base_url,
+ f"openai/deployments/{deployment_name}/chat/completions"
+ )
+ return f"{url}?api-version={self.api_version}"
+
+ def _build_headers(self) -> dict[str, str]:
+ """Build headers for Azure OpenAI API with api-key header."""
+ return {
+ "Content-Type": "application/json",
+ "api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
+ "x-session-affinity": uuid.uuid4().hex, # For cache locality
+ }
+
+ @staticmethod
+ def _supports_temperature(
+ deployment_name: str,
+ reasoning_effort: str | None = None,
+ ) -> bool:
+ """Return True when temperature is likely supported for this deployment."""
+ if reasoning_effort:
+ return False
+ name = deployment_name.lower()
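+ # Assumption: reasoning-style deployments (gpt-5 / o1 / o3 / o4 families)
+ # reject an explicit temperature, so it is omitted for them.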
+ return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
+
+ def _prepare_request_payload(
+ self,
+ deployment_name: str,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+ max_tokens: int = 4096,
+ temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ ) -> dict[str, Any]:
+ """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
+ payload: dict[str, Any] = {
+ "messages": self._sanitize_request_messages(
+ self._sanitize_empty_content(messages),
+ _AZURE_MSG_KEYS,
+ ),
+ "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
+ }
+
+ if self._supports_temperature(deployment_name, reasoning_effort):
+ payload["temperature"] = temperature
+
+ if reasoning_effort:
+ payload["reasoning_effort"] = reasoning_effort
+
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = "auto"
+
+ return payload
+
+ async def chat(
+ self,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ max_tokens: int = 4096,
+ temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ ) -> LLMResponse:
+ """
+ Send a chat completion request to Azure OpenAI.
+
+ Args:
+ messages: List of message dicts with 'role' and 'content'.
+ tools: Optional list of tool definitions in OpenAI format.
+ model: Model identifier (used as deployment name).
+ max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
+ temperature: Sampling temperature.
+ reasoning_effort: Optional reasoning effort parameter.
+
+ Returns:
+ LLMResponse with content and/or tool calls.
+ """
+ deployment_name = model or self.default_model
+ url = self._build_chat_url(deployment_name)
+ headers = self._build_headers()
+ payload = self._prepare_request_payload(
+ deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
+ )
+
+ try:
+ async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
+ response = await client.post(url, headers=headers, json=payload)
+ if response.status_code != 200:
+ return LLMResponse(
+ content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
+ finish_reason="error",
+ )
+
+ response_data = response.json()
+ return self._parse_response(response_data)
+
+ except Exception as e:
+ return LLMResponse(
+ content=f"Error calling Azure OpenAI: {repr(e)}",
+ finish_reason="error",
+ )
+
+ def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
+ """Parse Azure OpenAI response into our standard format."""
+ try:
+ choice = response["choices"][0]
+ message = choice["message"]
+
+ tool_calls = []
+ if message.get("tool_calls"):
+ for tc in message["tool_calls"]:
+ # Parse arguments from JSON string if needed
+ args = tc["function"]["arguments"]
+ if isinstance(args, str):
+ args = json_repair.loads(args)
+
+ tool_calls.append(
+ ToolCallRequest(
+ id=tc["id"],
+ name=tc["function"]["name"],
+ arguments=args,
+ )
+ )
+
+ usage = {}
+ if response.get("usage"):
+ usage_data = response["usage"]
+ usage = {
+ "prompt_tokens": usage_data.get("prompt_tokens", 0),
+ "completion_tokens": usage_data.get("completion_tokens", 0),
+ "total_tokens": usage_data.get("total_tokens", 0),
+ }
+
+ reasoning_content = message.get("reasoning_content") or None
+
+ return LLMResponse(
+ content=message.get("content"),
+ tool_calls=tool_calls,
+ finish_reason=choice.get("finish_reason", "stop"),
+ usage=usage,
+ reasoning_content=reasoning_content,
+ )
+
+ except (KeyError, IndexError) as e:
+ return LLMResponse(
+ content=f"Error parsing Azure OpenAI response: {str(e)}",
+ finish_reason="error",
+ )
+
+ def get_default_model(self) -> str:
+ """Get the default model (also used as default deployment name)."""
+ return self.default_model
\ No newline at end of file
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 55bd805..0f73544 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg)
return result
+ @staticmethod
+ def _sanitize_request_messages(
+ messages: list[dict[str, Any]],
+ allowed_keys: frozenset[str],
+ ) -> list[dict[str, Any]]:
+ """Keep only provider-safe message keys and normalize assistant content."""
+ sanitized = []
+ for msg in messages:
+ clean = {k: v for k, v in msg.items() if k in allowed_keys}
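+ # Strict providers require "content" even when the assistant
+ # message only carries tool_calls.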
+ if clean.get("role") == "assistant" and "content" not in clean:
+ clean["content"] = None
+ sanitized.append(clean)
+ return sanitized
+
@abstractmethod
async def chat(
self,
diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py
index 56e6270..66df734 100644
--- a/nanobot/providers/custom_provider.py
+++ b/nanobot/providers/custom_provider.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import uuid
from typing import Any
import json_repair
@@ -15,7 +16,12 @@ class CustomProvider(LLMProvider):
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
super().__init__(api_key, api_base)
self.default_model = default_model
- self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
+ # Keep affinity stable for this provider instance to improve backend cache locality.
+ self._client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ default_headers={"x-session-affinity": uuid.uuid4().hex},
+ )
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py
index d8d8ace..cb67635 100644
--- a/nanobot/providers/litellm_provider.py
+++ b/nanobot/providers/litellm_provider.py
@@ -1,5 +1,6 @@
"""LiteLLM provider implementation for multi-provider support."""
+import hashlib
import os
import secrets
import string
@@ -8,6 +9,7 @@ from typing import Any
import json_repair
import litellm
from litellm import acompletion
+from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
@@ -165,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
+ @staticmethod
+ def _normalize_tool_call_id(tool_call_id: Any) -> Any:
+ """Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
+ if not isinstance(tool_call_id, str):
+ return tool_call_id
+ if len(tool_call_id) == 9 and tool_call_id.isalnum():
+ return tool_call_id
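+ # Otherwise hash the id down to a deterministic 9-char hex form.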
+ return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
+
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
- sanitized = []
- for msg in messages:
- clean = {k: v for k, v in msg.items() if k in allowed}
- # Strict providers require "content" even when assistant only has tool_calls
- if clean.get("role") == "assistant" and "content" not in clean:
- clean["content"] = None
- sanitized.append(clean)
+ sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
+ id_map: dict[str, str] = {}
+
+ def map_id(value: Any) -> Any:
+ if not isinstance(value, str):
+ return value
+ return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
+
+ for clean in sanitized:
+ # Keep assistant tool_calls[].id and tool tool_call_id in sync after
+ # shortening, otherwise strict providers reject the broken linkage.
+ if isinstance(clean.get("tool_calls"), list):
+ normalized_tool_calls = []
+ for tc in clean["tool_calls"]:
+ if not isinstance(tc, dict):
+ normalized_tool_calls.append(tc)
+ continue
+ tc_clean = dict(tc)
+ tc_clean["id"] = map_id(tc_clean.get("id"))
+ normalized_tool_calls.append(tc_clean)
+ clean["tool_calls"] = normalized_tool_calls
+
+ if "tool_call_id" in clean and clean["tool_call_id"]:
+ clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
@@ -255,20 +283,37 @@ class LiteLLMProvider(LLMProvider):
"""Parse LiteLLM response into our standard format."""
choice = response.choices[0]
message = choice.message
+ content = message.content
+ finish_reason = choice.finish_reason
+
+ # Some providers (e.g. GitHub Copilot) split content and tool_calls
+ # across multiple choices. Merge them so tool_calls are not lost.
+ raw_tool_calls = []
+ for ch in response.choices:
+ msg = ch.message
+ if hasattr(msg, "tool_calls") and msg.tool_calls:
+ raw_tool_calls.extend(msg.tool_calls)
+ if ch.finish_reason in ("tool_calls", "stop"):
+ finish_reason = ch.finish_reason
+ if not content and msg.content:
+ content = msg.content
+
+ if len(response.choices) > 1:
+ logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
+ len(response.choices), len(raw_tool_calls))
tool_calls = []
- if hasattr(message, "tool_calls") and message.tool_calls:
- for tc in message.tool_calls:
- # Parse arguments from JSON string if needed
- args = tc.function.arguments
- if isinstance(args, str):
- args = json_repair.loads(args)
+ for tc in raw_tool_calls:
+ # Parse arguments from JSON string if needed
+ args = tc.function.arguments
+ if isinstance(args, str):
+ args = json_repair.loads(args)
- tool_calls.append(ToolCallRequest(
- id=_short_tool_id(),
- name=tc.function.name,
- arguments=args,
- ))
+ tool_calls.append(ToolCallRequest(
+ id=_short_tool_id(),
+ name=tc.function.name,
+ arguments=args,
+ ))
usage = {}
if hasattr(response, "usage") and response.usage:
@@ -280,11 +325,11 @@ class LiteLLMProvider(LLMProvider):
reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None
-
+
return LLMResponse(
- content=message.content,
+ content=content,
tool_calls=tool_calls,
- finish_reason=choice.finish_reason or "stop",
+ finish_reason=finish_reason or "stop",
usage=usage,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py
index b6afa65..d04e210 100644
--- a/nanobot/providers/openai_codex_provider.py
+++ b/nanobot/providers/openai_codex_provider.py
@@ -52,6 +52,9 @@ class OpenAICodexProvider(LLMProvider):
"parallel_tool_calls": True,
}
+ if reasoning_effort:
+ body["reasoning"] = {"effort": reasoning_effort}
+
if tools:
body["tools"] = _convert_tools(tools)
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index df915b7..3ba1a0e 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -26,33 +26,33 @@ class ProviderSpec:
"""
# identity
- name: str # config field name, e.g. "dashscope"
- keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
- env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
- display_name: str = "" # shown in `nanobot status`
+ name: str # config field name, e.g. "dashscope"
+ keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
+ env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
+ display_name: str = "" # shown in `nanobot status`
# model prefixing
- litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
- skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
+ litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
+ skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection
- is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
- is_local: bool = False # local deployment (vLLM, Ollama)
- detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
- detect_by_base_keyword: str = "" # match substring in api_base URL
- default_api_base: str = "" # fallback base URL
+ is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
+ is_local: bool = False # local deployment (vLLM, Ollama)
+ detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
+ detect_by_base_keyword: str = "" # match substring in api_base URL
+ default_api_base: str = "" # fallback base URL
# gateway behavior
- strip_model_prefix: bool = False # strip "provider/" before re-prefixing
+ strip_model_prefix: bool = False # strip "provider/" before re-prefixing
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
- is_oauth: bool = False # if True, uses OAuth flow instead of API key
+ is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False
@@ -70,7 +70,6 @@ class ProviderSpec:
# ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = (
-
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec(
name="custom",
@@ -81,16 +80,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
is_direct=True,
),
+ # === Azure OpenAI (direct API calls with API version 2024-10-21) =====
+ ProviderSpec(
+ name="azure_openai",
+ keywords=("azure", "azure-openai"),
+ env_key="",
+ display_name="Azure OpenAI",
+ litellm_prefix="",
+ is_direct=True,
+ ),
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
-
# OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
- litellm_prefix="openrouter", # claude-3 β openrouter/claude-3
+ litellm_prefix="openrouter", # claude-3 β openrouter/claude-3
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -102,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
-
# AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
- env_key="OPENAI_API_KEY", # OpenAI-compatible
+ env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix",
- litellm_prefix="openai", # β openai/{model}
+ litellm_prefix="openai", # β openai/{model}
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -119,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="",
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
- strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
+ strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(),
),
-
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec(
name="siliconflow",
@@ -140,7 +145,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# VolcEngine (火山引擎): OpenAI-compatible gateway
ProviderSpec(
name="volcengine",
@@ -158,9 +162,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# === Standard providers (matched by model-name keywords) ===============
-
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
name="anthropic",
@@ -179,7 +181,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
-
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec(
name="openai",
@@ -197,14 +198,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
keywords=("openai-codex",),
- env_key="", # OAuth-based, no API key
+ env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
- litellm_prefix="", # Not routed through LiteLLM
+ litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
@@ -214,16 +214,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False,
model_overrides=(),
- is_oauth=True, # OAuth-based authentication
+ is_oauth=True, # OAuth-based authentication
),
-
# Github Copilot: uses OAuth, not API key.
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
- env_key="", # OAuth-based, no API key
+ env_key="", # OAuth-based, no API key
display_name="Github Copilot",
- litellm_prefix="github_copilot", # github_copilot/model β github_copilot/model
+ litellm_prefix="github_copilot", # github_copilot/model β github_copilot/model
skip_prefixes=("github_copilot/",),
env_extras=(),
is_gateway=False,
@@ -233,17 +232,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
- is_oauth=True, # OAuth-based authentication
+ is_oauth=True, # OAuth-based authentication
),
-
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek",
- litellm_prefix="deepseek", # deepseek-chat β deepseek/deepseek-chat
- skip_prefixes=("deepseek/",), # avoid double-prefix
+ litellm_prefix="deepseek", # deepseek-chat β deepseek/deepseek-chat
+ skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -253,15 +251,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
- litellm_prefix="gemini", # gemini-pro β gemini/gemini-pro
- skip_prefixes=("gemini/",), # avoid double-prefix
+ litellm_prefix="gemini", # gemini-pro β gemini/gemini-pro
+ skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -271,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +276,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
- litellm_prefix="zai", # glm-4 β zai/glm-4
+ litellm_prefix="zai", # glm-4 β zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
- env_extras=(
- ("ZHIPUAI_API_KEY", "{api_key}"),
- ),
+ env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
@@ -293,14 +287,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
- litellm_prefix="dashscope", # qwen-max β dashscope/qwen-max
+ litellm_prefix="dashscope", # qwen-max β dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -311,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +312,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
- litellm_prefix="moonshot", # kimi-k2.5 β moonshot/kimi-k2.5
+ litellm_prefix="moonshot", # kimi-k2.5 β moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"),
- env_extras=(
- ("MOONSHOT_API_BASE", "{api_base}"),
- ),
+ env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
- default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
+ default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False,
- model_overrides=(
- ("kimi-k2.5", {"temperature": 1.0}),
- ),
+ model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
),
-
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec(
@@ -343,7 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
- litellm_prefix="minimax", # MiniMax-M2.1 β minimax/MiniMax-M2.1
+ litellm_prefix="minimax", # MiniMax-M2.1 β minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -354,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# === Local deployment (matched by config key, NOT by api_base) =========
-
# vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec(
@@ -364,20 +349,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local",
- litellm_prefix="hosted_vllm", # Llama-3-8B β hosted_vllm/Llama-3-8B
+ litellm_prefix="hosted_vllm", # Llama-3-8B β hosted_vllm/Llama-3-8B
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="",
- default_api_base="", # user must provide in config
+ default_api_base="", # user must provide in config
strip_model_prefix=False,
model_overrides=(),
),
-
# === Auxiliary (not a primary LLM provider) ============================
-
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last β it rarely wins fallback.
ProviderSpec(
@@ -385,8 +368,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
- litellm_prefix="groq", # llama3-8b-8192 β groq/llama3-8b-8192
- skip_prefixes=("groq/",), # avoid double-prefix
+ litellm_prefix="groq", # llama3-8b-8192 β groq/llama3-8b-8192
+ skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -403,6 +386,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers
# ---------------------------------------------------------------------------
+
def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local β those are matched by api_key/api_base instead."""
@@ -418,7 +402,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec
for spec in std_specs:
- if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
+ if any(
+ kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
+ ):
return spec
return None
diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py
index 3a8c802..c57c365 100644
--- a/nanobot/utils/helpers.py
+++ b/nanobot/utils/helpers.py
@@ -5,6 +5,19 @@ from datetime import datetime
from pathlib import Path
+def detect_image_mime(data: bytes) -> str | None:
+ """Detect image MIME type from magic bytes, ignoring file extension."""
+ if data[:8] == b"\x89PNG\r\n\x1a\n":
+ return "image/png"
+ if data[:3] == b"\xff\xd8\xff":
+ return "image/jpeg"
+ if data[:6] in (b"GIF87a", b"GIF89a"):
+ return "image/gif"
+ if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
+ return "image/webp"
+ return None
+
+
def ensure_dir(path: Path) -> Path:
"""Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True)
@@ -34,6 +47,38 @@ def safe_filename(name: str) -> str:
return _UNSAFE_CHARS.sub("_", name).strip()
+def split_message(content: str, max_len: int = 2000) -> list[str]:
+ """
+ Split content into chunks within max_len, preferring line breaks.
+
+ Args:
+ content: The text content to split.
+ max_len: Maximum length per chunk (default 2000 for Discord compatibility).
+
+ Returns:
+ List of message chunks, each within max_len.
+ """
+ if not content:
+ return []
+ if len(content) <= max_len:
+ return [content]
+ chunks: list[str] = []
+ while content:
+ if len(content) <= max_len:
+ chunks.append(content)
+ break
+ cut = content[:max_len]
+ # Try to break at newline first, then space, then hard break
+ pos = cut.rfind('\n')
+ if pos <= 0:
+ pos = cut.rfind(' ')
+ if pos <= 0:
+ pos = max_len
+ chunks.append(content[:pos])
+ content = content[pos:].lstrip()
+ return chunks
+
+
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files
diff --git a/pyproject.toml b/pyproject.toml
index 4199af1..41d0fbb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ dependencies = [
"rich>=14.0.0,<15.0.0",
"croniter>=6.0.0,<7.0.0",
"dingtalk-stream>=0.24.0,<1.0.0",
- "python-telegram-bot[socks]>=22.0,<23.0",
+ "python-telegram-bot[socks]>=22.6,<23.0",
"lark-oapi>=1.5.0,<2.0.0",
"socksio>=1.0.0,<2.0.0",
"python-socketio>=5.16.0,<6.0.0",
@@ -42,6 +42,7 @@ dependencies = [
"prompt-toolkit>=3.0.50,<4.0.0",
"mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0",
+ "chardet>=3.0.2,<6.0.0",
"openai>=2.8.0",
]
@@ -55,6 +56,9 @@ dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"ruff>=0.1.0",
+ "matrix-nio[e2e]>=0.25.2",
+ "mistune>=3.0.0,<4.0.0",
+ "nh3>=0.2.17,<1.0.0",
]
[project.scripts]
diff --git a/tests/test_azure_openai_provider.py b/tests/test_azure_openai_provider.py
new file mode 100644
index 0000000..77f36d4
--- /dev/null
+++ b/tests/test_azure_openai_provider.py
@@ -0,0 +1,399 @@
+"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
+
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+from nanobot.providers.base import LLMResponse
+
+
+def test_azure_openai_provider_init():
+ """Test AzureOpenAIProvider initialization without deployment_name."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+
+ assert provider.api_key == "test-key"
+ assert provider.api_base == "https://test-resource.openai.azure.com/"
+ assert provider.default_model == "gpt-4o-deployment"
+ assert provider.api_version == "2024-10-21"
+
+
+def test_azure_openai_provider_init_validation():
+ """Test AzureOpenAIProvider initialization validation."""
+ # Missing api_key
+ with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
+ AzureOpenAIProvider(api_key="", api_base="https://test.com")
+
+ # Missing api_base
+ with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
+ AzureOpenAIProvider(api_key="test", api_base="")
+
+
+def test_build_chat_url():
+ """Test Azure OpenAI URL building with different deployment names."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Test various deployment names
+ test_cases = [
+ ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
+ ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
+ ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
+ ]
+
+ for deployment_name, expected_url in test_cases:
+ url = provider._build_chat_url(deployment_name)
+ assert url == expected_url
+
+
+def test_build_chat_url_api_base_without_slash():
+ """Test URL building when api_base doesn't end with slash."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com", # No trailing slash
+ default_model="gpt-4o",
+ )
+
+ url = provider._build_chat_url("test-deployment")
+ expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
+ assert url == expected
+
+
+def test_build_headers():
+ """Test Azure OpenAI header building with api-key authentication."""
+ provider = AzureOpenAIProvider(
+ api_key="test-api-key-123",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ headers = provider._build_headers()
+ assert headers["Content-Type"] == "application/json"
+ assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
+ assert "x-session-affinity" in headers
+
+
+def test_prepare_request_payload():
+ """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ messages = [{"role": "user", "content": "Hello"}]
+ payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
+
+ assert payload["messages"] == messages
+ assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
+ assert payload["temperature"] == 0.8
+ assert "tools" not in payload
+
+ # Test with tools
+ tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
+ payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
+ assert payload_with_tools["tools"] == tools
+ assert payload_with_tools["tool_choice"] == "auto"
+
+ # Test with reasoning_effort
+ payload_with_reasoning = provider._prepare_request_payload(
+ "gpt-5-chat", messages, reasoning_effort="medium"
+ )
+ assert payload_with_reasoning["reasoning_effort"] == "medium"
+ assert "temperature" not in payload_with_reasoning
+
+
+def test_prepare_request_payload_sanitizes_messages():
+ """Test Azure payload strips non-standard message keys before sending."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ messages = [
+ {
+ "role": "assistant",
+ "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
+ "reasoning_content": "hidden chain-of-thought",
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_123",
+ "name": "x",
+ "content": "ok",
+ "extra_field": "should be removed",
+ },
+ ]
+
+ payload = provider._prepare_request_payload("gpt-4o", messages)
+
+ assert payload["messages"] == [
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_123",
+ "name": "x",
+ "content": "ok",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_chat_success():
+ """Test successful chat request using model as deployment name."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+
+ # Mock response data
+ mock_response_data = {
+ "choices": [{
+ "message": {
+ "content": "Hello! How can I help you today?",
+ "role": "assistant"
+ },
+ "finish_reason": "stop"
+ }],
+ "usage": {
+ "prompt_tokens": 12,
+ "completion_tokens": 18,
+ "total_tokens": 30
+ }
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ # Test with specific model (deployment name)
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages, model="custom-deployment")
+
+ assert isinstance(result, LLMResponse)
+ assert result.content == "Hello! How can I help you today?"
+ assert result.finish_reason == "stop"
+ assert result.usage["prompt_tokens"] == 12
+ assert result.usage["completion_tokens"] == 18
+ assert result.usage["total_tokens"] == 30
+
+ # Verify URL was built with the provided model as deployment name
+ call_args = mock_context.post.call_args
+ expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
+ assert call_args[0][0] == expected_url
+
+
+@pytest.mark.asyncio
+async def test_chat_uses_default_model_when_no_model_provided():
+ """Test that chat uses default_model when no model is specified."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="default-deployment",
+ )
+
+ mock_response_data = {
+ "choices": [{
+ "message": {"content": "Response", "role": "assistant"},
+ "finish_reason": "stop"
+ }],
+ "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Test"}]
+ await provider.chat(messages) # No model specified
+
+ # Verify URL was built with default model as deployment name
+ call_args = mock_context.post.call_args
+ expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
+ assert call_args[0][0] == expected_url
+
+
+@pytest.mark.asyncio
+async def test_chat_with_tool_calls():
+ """Test chat request with tool calls in response."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Mock response with tool calls
+ mock_response_data = {
+ "choices": [{
+ "message": {
+ "content": None,
+ "role": "assistant",
+ "tool_calls": [{
+ "id": "call_12345",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"location": "San Francisco"}'
+ }
+ }]
+ },
+ "finish_reason": "tool_calls"
+ }],
+ "usage": {
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ "total_tokens": 35
+ }
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "What's the weather?"}]
+ tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
+ result = await provider.chat(messages, tools=tools, model="weather-model")
+
+ assert isinstance(result, LLMResponse)
+ assert result.content is None
+ assert result.finish_reason == "tool_calls"
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].name == "get_weather"
+ assert result.tool_calls[0].arguments == {"location": "San Francisco"}
+
+
+@pytest.mark.asyncio
+async def test_chat_api_error():
+ """Test chat request API error handling."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 401
+ mock_response.text = "Invalid authentication credentials"
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages)
+
+ assert isinstance(result, LLMResponse)
+ assert "Azure OpenAI API Error 401" in result.content
+ assert "Invalid authentication credentials" in result.content
+ assert result.finish_reason == "error"
+
+
+@pytest.mark.asyncio
+async def test_chat_connection_error():
+ """Test chat request connection error handling."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages)
+
+ assert isinstance(result, LLMResponse)
+ assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
+ assert result.finish_reason == "error"
+
+
+def test_parse_response_malformed():
+ """Test response parsing with malformed data."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Test with missing choices
+ malformed_response = {"usage": {"prompt_tokens": 10}}
+ result = provider._parse_response(malformed_response)
+
+ assert isinstance(result, LLMResponse)
+ assert "Error parsing Azure OpenAI response" in result.content
+ assert result.finish_reason == "error"
+
+
+def test_get_default_model():
+ """Test get_default_model method."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="my-custom-deployment",
+ )
+
+ assert provider.get_default_model() == "my-custom-deployment"
+
+
+if __name__ == "__main__":
+ # Run basic tests
+ print("Running basic Azure OpenAI provider tests...")
+
+ # Test initialization
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+ print("β
Provider initialization successful")
+
+ # Test URL building
+ url = provider._build_chat_url("my-deployment")
+ expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
+ assert url == expected
+ print("β
URL building works correctly")
+
+ # Test headers
+ headers = provider._build_headers()
+ assert headers["api-key"] == "test-key"
+ assert headers["Content-Type"] == "application/json"
+ print("β
Header building works correctly")
+
+ # Test payload preparation
+ messages = [{"role": "user", "content": "Test"}]
+ payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
+ assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
+ print("β
Payload preparation works correctly")
+
+ print("β
All basic tests passed! Updated test file is working correctly.")
\ No newline at end of file
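
The Azure tests above pin down three behaviors worth seeing side by side: the deployment-based URL, the `max_completion_tokens` rename, and the message sanitization. Below is a minimal sketch of that contract, assuming the 2024-10-21 API version the tests use; it is an illustration, not the actual `AzureOpenAIProvider` code.

```python
# Hedged sketch of the behavior asserted by the tests above; names and
# defaults mirror the test assertions, not the real nanobot implementation.
_ALLOWED_MESSAGE_KEYS = {"role", "content", "tool_calls", "tool_call_id", "name"}

def build_chat_url(api_base: str, deployment: str) -> str:
    # Azure routes on the deployment name, not the model name.
    return f"{api_base}/openai/deployments/{deployment}/chat/completions?api-version=2024-10-21"

def prepare_request_payload(messages, tools=None, max_tokens=None,
                            temperature=None, reasoning_effort=None) -> dict:
    sanitized = []
    for msg in messages:
        clean = {k: v for k, v in msg.items() if k in _ALLOWED_MESSAGE_KEYS}
        if "tool_calls" in clean:
            clean.setdefault("content", None)  # assistant tool-call turns need explicit content
        sanitized.append(clean)

    payload: dict = {"messages": sanitized}
    if max_tokens is not None:
        payload["max_completion_tokens"] = max_tokens  # 2024-10-21 renamed max_tokens
    if reasoning_effort is not None:
        payload["reasoning_effort"] = reasoning_effort  # reasoning models omit temperature
    elif temperature is not None:
        payload["temperature"] = temperature
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"
    return payload
```
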
diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py
index 9afcc7d..ce796e2 100644
--- a/tests/test_context_prompt_cache.py
+++ b/tests/test_context_prompt_cache.py
@@ -40,7 +40,7 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
- """Runtime metadata should be a separate user message before the actual user message."""
+ """Runtime metadata should be merged with the user message."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
@@ -54,13 +54,12 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"]
- assert messages[-2]["role"] == "user"
- runtime_content = messages[-2]["content"]
- assert isinstance(runtime_content, str)
- assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
- assert "Current Time:" in runtime_content
- assert "Channel: cli" in runtime_content
- assert "Chat ID: direct" in runtime_content
-
+ # Runtime context is now merged with user message into a single message
assert messages[-1]["role"] == "user"
- assert messages[-1]["content"] == "Return exactly: OK"
+ user_content = messages[-1]["content"]
+ assert isinstance(user_content, str)
+ assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
+ assert "Current Time:" in user_content
+ assert "Channel: cli" in user_content
+ assert "Chat ID: direct" in user_content
+ assert "Return exactly: OK" in user_content
diff --git a/tests/test_cron_commands.py b/tests/test_cron_commands.py
deleted file mode 100644
index bce1ef5..0000000
--- a/tests/test_cron_commands.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typer.testing import CliRunner
-
-from nanobot.cli.commands import app
-
-runner = CliRunner()
-
-
-def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
- monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
-
- result = runner.invoke(
- app,
- [
- "cron",
- "add",
- "--name",
- "demo",
- "--message",
- "hello",
- "--cron",
- "0 9 * * *",
- "--tz",
- "America/Vancovuer",
- ],
- )
-
- assert result.exit_code == 1
- assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
- assert not (tmp_path / "cron" / "jobs.json").exists()
diff --git a/tests/test_cron_service.py b/tests/test_cron_service.py
index 2a36f4c..9631da5 100644
--- a/tests/test_cron_service.py
+++ b/tests/test_cron_service.py
@@ -48,6 +48,8 @@ async def test_running_service_honors_external_disable(tmp_path) -> None:
)
await service.start()
try:
+ # Wait slightly to ensure file mtime is definitively different
+ await asyncio.sleep(0.05)
external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False)
assert updated is not None
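
The added sleep suggests the running service notices external edits by polling the store file's modification time, so two writes within the filesystem's mtime resolution could be missed. A toy version of that detection pattern (an assumption about the mechanism, not the CronService code):

```python
import asyncio
from pathlib import Path

async def watch_store(path: Path, reload, interval: float = 0.5) -> None:
    """Poll a JSON store and invoke reload() when its mtime changes (sketch)."""
    last = path.stat().st_mtime if path.exists() else 0.0
    while True:
        await asyncio.sleep(interval)
        mtime = path.stat().st_mtime if path.exists() else 0.0
        if mtime != last:  # coarse mtime granularity is why the test sleeps first
            last = mtime
            reload()
```
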
diff --git a/tests/test_feishu_post_content.py b/tests/test_feishu_post_content.py
index bf1ea82..7b1cb9d 100644
--- a/tests/test_feishu_post_content.py
+++ b/tests/test_feishu_post_content.py
@@ -1,4 +1,4 @@
-from nanobot.channels.feishu import _extract_post_content
+from nanobot.channels.feishu import FeishuChannel, _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None:
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]
+
+
+def test_register_optional_event_keeps_builder_when_method_missing() -> None:
+ class Builder:
+ pass
+
+ builder = Builder()
+ same = FeishuChannel._register_optional_event(builder, "missing", object())
+ assert same is builder
+
+
+def test_register_optional_event_calls_supported_method() -> None:
+ called = []
+
+ class Builder:
+ def register_event(self, handler):
+ called.append(handler)
+ return self
+
+ builder = Builder()
+ handler = object()
+ same = FeishuChannel._register_optional_event(builder, "register_event", handler)
+
+ assert same is builder
+ assert called == [handler]
diff --git a/tests/test_feishu_table_split.py b/tests/test_feishu_table_split.py
new file mode 100644
index 0000000..af8fa16
--- /dev/null
+++ b/tests/test_feishu_table_split.py
@@ -0,0 +1,104 @@
+"""Tests for FeishuChannel._split_elements_by_table_limit.
+
+Feishu cards reject messages that contain more than one table element
+(API error 11310: card table number over limit). The helper splits a flat
+list of card elements into groups so that each group contains at most one
+table, allowing nanobot to send multiple cards instead of failing.
+"""
+
+from nanobot.channels.feishu import FeishuChannel
+
+
+def _md(text: str) -> dict:
+ return {"tag": "markdown", "content": text}
+
+
+def _table() -> dict:
+ return {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+ "rows": [{"c0": "v"}],
+ "page_size": 2,
+ }
+
+
+split = FeishuChannel._split_elements_by_table_limit
+
+
+def test_empty_list_returns_single_empty_group() -> None:
+ assert split([]) == [[]]
+
+
+def test_no_tables_returns_single_group() -> None:
+ els = [_md("hello"), _md("world")]
+ result = split(els)
+ assert result == [els]
+
+
+def test_single_table_stays_in_one_group() -> None:
+ els = [_md("intro"), _table(), _md("outro")]
+ result = split(els)
+ assert len(result) == 1
+ assert result[0] == els
+
+
+def test_two_tables_split_into_two_groups() -> None:
+ # Use different row values so the two tables are not equal
+ t1 = {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+ "rows": [{"c0": "table-one"}],
+ "page_size": 2,
+ }
+ t2 = {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
+ "rows": [{"c0": "table-two"}],
+ "page_size": 2,
+ }
+ els = [_md("before"), t1, _md("between"), t2, _md("after")]
+ result = split(els)
+ assert len(result) == 2
+ # First group: text before table-1 + table-1
+ assert t1 in result[0]
+ assert t2 not in result[0]
+ # Second group: text between tables + table-2 + text after
+ assert t2 in result[1]
+ assert t1 not in result[1]
+
+
+def test_three_tables_split_into_three_groups() -> None:
+ tables = [
+ {"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
+ for i in range(3)
+ ]
+ els = tables[:]
+ result = split(els)
+ assert len(result) == 3
+ for i, group in enumerate(result):
+ assert tables[i] in group
+
+
+def test_leading_markdown_stays_with_first_table() -> None:
+ intro = _md("intro")
+ t = _table()
+ result = split([intro, t])
+ assert len(result) == 1
+ assert result[0] == [intro, t]
+
+
+def test_trailing_markdown_after_second_table() -> None:
+ t1, t2 = _table(), _table()
+ tail = _md("end")
+ result = split([t1, t2, tail])
+ assert len(result) == 2
+ assert result[1] == [t2, tail]
+
+
+def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
+ head = _md("head")
+ t1, t2 = _table(), _table()
+ result = split([head, t1, t2])
+ # head + t1 in group 0; t2 in group 1
+ assert result[0] == [head, t1]
+ assert result[1] == [t2]
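
The grouping rule these tests pin down fits a single pass with a holding buffer: non-table elements wait in `pending` until the next table decides whether they close the current card or open the next one. A hedged reimplementation of `_split_elements_by_table_limit`:

```python
# Sketch matching the test expectations above; the real helper is a
# staticmethod on FeishuChannel and may differ in detail.
def split_elements_by_table_limit(elements: list[dict]) -> list[list[dict]]:
    groups: list[list[dict]] = []
    current: list[dict] = []
    pending: list[dict] = []  # non-table elements awaiting placement
    for el in elements:
        if el.get("tag") != "table":
            pending.append(el)
            continue
        if any(e.get("tag") == "table" for e in current):
            groups.append(current)       # current card already has its one table
            current = pending + [el]     # text between tables leads the next card
        else:
            current += pending + [el]
        pending = []
    current.extend(pending)              # trailing text stays with the last card
    groups.append(current)               # empty input yields [[]], as tested
    return groups
```
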
diff --git a/tests/test_matrix_channel.py b/tests/test_matrix_channel.py
index c6714c2..c25b95a 100644
--- a/tests/test_matrix_channel.py
+++ b/tests/test_matrix_channel.py
@@ -159,6 +159,7 @@ class _FakeAsyncClient:
def _make_config(**kwargs) -> MatrixConfig:
+ kwargs.setdefault("allow_from", ["*"])
return MatrixConfig(
enabled=True,
homeserver="https://matrix.org",
@@ -274,7 +275,7 @@ async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
@pytest.mark.asyncio
-async def test_room_invite_joins_when_allow_list_is_empty() -> None:
+async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
@@ -284,9 +285,22 @@ async def test_room_invite_joins_when_allow_list_is_empty() -> None:
await channel._on_room_invite(room, event)
- assert client.join_calls == ["!room:matrix.org"]
+    assert client.join_calls == []
+
+
+@pytest.mark.asyncio
+async def test_room_invite_joins_when_sender_allowed() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == ["!room:matrix.org"]
+
+
@pytest.mark.asyncio
async def test_room_invite_respects_allow_list_when_configured() -> None:
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
@@ -1163,6 +1177,8 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
assert "!room:matrix.org" in channel._typing_tasks
assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
+ await channel.stop()
+
@pytest.mark.asyncio
async def test_send_clears_typing_when_send_fails() -> None:
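
The invite tests encode a deliberate default flip: an empty `allow_from` now means trust nobody rather than auto-join everything. A sketch of the guard (hypothetical; the wildcard handling is an assumption based on `_make_config` defaulting to `["*"]`):

```python
# Illustration of the invite policy the tests assert; not the actual
# MatrixChannel._on_room_invite implementation.
async def on_room_invite(self, room, event) -> None:
    allow_from = self.config.allow_from or []
    if not allow_from:          # empty list: ignore all invites
        return
    if "*" not in allow_from and event.sender not in allow_from:
        return                  # sender not allow-listed: ignore
    await self.client.join(room.room_id)
```
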
diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py
index 375c802..ff15584 100644
--- a/tests/test_memory_consolidation_types.py
+++ b/tests/test_memory_consolidation_types.py
@@ -145,3 +145,78 @@ class TestMemoryConsolidationTypeHandling:
assert result is True
provider.chat.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
+ """Some providers return arguments as a list - extract first element if it's a dict."""
+ store = MemoryStore(tmp_path)
+ provider = AsyncMock()
+
+ # Simulate arguments being a list containing a dict
+ response = LLMResponse(
+ content=None,
+ tool_calls=[
+ ToolCallRequest(
+ id="call_1",
+ name="save_memory",
+ arguments=[{
+ "history_entry": "[2026-01-01] User discussed testing.",
+ "memory_update": "# Memory\nUser likes testing.",
+ }],
+ )
+ ],
+ )
+ provider.chat = AsyncMock(return_value=response)
+ session = _make_session(message_count=60)
+
+ result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+ assert result is True
+ assert "User discussed testing." in store.history_file.read_text()
+ assert "User likes testing." in store.memory_file.read_text()
+
+ @pytest.mark.asyncio
+ async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
+ """Empty list arguments should return False."""
+ store = MemoryStore(tmp_path)
+ provider = AsyncMock()
+
+ response = LLMResponse(
+ content=None,
+ tool_calls=[
+ ToolCallRequest(
+ id="call_1",
+ name="save_memory",
+ arguments=[],
+ )
+ ],
+ )
+ provider.chat = AsyncMock(return_value=response)
+ session = _make_session(message_count=60)
+
+ result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+ assert result is False
+
+ @pytest.mark.asyncio
+ async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
+ """List with non-dict content should return False."""
+ store = MemoryStore(tmp_path)
+ provider = AsyncMock()
+
+ response = LLMResponse(
+ content=None,
+ tool_calls=[
+ ToolCallRequest(
+ id="call_1",
+ name="save_memory",
+ arguments=["string", "content"],
+ )
+ ],
+ )
+ provider.chat = AsyncMock(return_value=response)
+ session = _make_session(message_count=60)
+
+ result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+ assert result is False
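
These three tests describe one small normalization step: list-shaped tool-call arguments are unwrapped to their first dict, and anything else aborts consolidation. A sketch of that step (hypothetical helper name; the real logic sits inside `MemoryStore.consolidate`):

```python
# Sketch of the argument normalization the tests above assert.
def normalize_tool_arguments(arguments) -> dict | None:
    if isinstance(arguments, dict):
        return arguments
    if isinstance(arguments, list) and arguments and isinstance(arguments[0], dict):
        return arguments[0]  # some providers wrap the payload in a one-element list
    return None              # empty list / non-dict contents: caller returns False
```
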
diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py
index 26b8a16..f5e65c9 100644
--- a/tests/test_message_tool_suppress.py
+++ b/tests/test_message_tool_suppress.py
@@ -86,6 +86,36 @@ class TestMessageToolSuppressLogic:
assert result is not None
assert "Hello" in result.content
+ @pytest.mark.asyncio
+ async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
+ loop = _make_loop(tmp_path)
+ tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
+ calls = iter([
+ LLMResponse(
+ content="Visiblehidden ",
+ tool_calls=[tool_call],
+ reasoning_content="secret reasoning",
+ thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
+ ),
+ LLMResponse(content="Done", tool_calls=[]),
+ ])
+ loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
+ loop.tools.get_definitions = MagicMock(return_value=[])
+ loop.tools.execute = AsyncMock(return_value="ok")
+
+ progress: list[tuple[str, bool]] = []
+
+ async def on_progress(content: str, *, tool_hint: bool = False) -> None:
+ progress.append((content, tool_hint))
+
+ final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
+
+ assert final_content == "Done"
+ assert progress == [
+ ("Visible", False),
+ ('read_file("foo.txt")', True),
+ ]
+
class TestMessageToolTurnTracking:
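
The progress test asserts two transformations on the way to the user: inline reasoning is stripped from streamed content, and tool calls surface as compact hints like `read_file("foo.txt")`. A rough sketch, where the `<think>` tag and the first-argument hint format are both assumptions inferred from the assertions:

```python
import json
import re

def strip_inline_reasoning(content: str) -> str:
    # Assumption: inline reasoning is wrapped in <think>...</think> spans.
    return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

def format_tool_hint(name: str, arguments: dict) -> str:
    # Renders read_file("foo.txt") from the first argument value (assumption).
    values = list(arguments.values())
    return f"{name}({json.dumps(values[0])})" if values else f"{name}()"
```
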
diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py
index cb50fb0..c2b4b6a 100644
--- a/tests/test_tool_validation.py
+++ b/tests/test_tool_validation.py
@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths
+
+
+# --- cast_params tests ---
+
+
+class CastTestTool(Tool):
+ """Minimal tool for testing cast_params."""
+
+ def __init__(self, schema: dict[str, Any]) -> None:
+ self._schema = schema
+
+ @property
+ def name(self) -> str:
+ return "cast_test"
+
+ @property
+ def description(self) -> str:
+ return "test tool for casting"
+
+ @property
+ def parameters(self) -> dict[str, Any]:
+ return self._schema
+
+ async def execute(self, **kwargs: Any) -> str:
+ return "ok"
+
+
+def test_cast_params_string_to_int() -> None:
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"count": {"type": "integer"}},
+ }
+ )
+ result = tool.cast_params({"count": "42"})
+ assert result["count"] == 42
+ assert isinstance(result["count"], int)
+
+
+def test_cast_params_string_to_number() -> None:
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"rate": {"type": "number"}},
+ }
+ )
+ result = tool.cast_params({"rate": "3.14"})
+ assert result["rate"] == 3.14
+ assert isinstance(result["rate"], float)
+
+
+def test_cast_params_string_to_bool() -> None:
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"enabled": {"type": "boolean"}},
+ }
+ )
+ assert tool.cast_params({"enabled": "true"})["enabled"] is True
+ assert tool.cast_params({"enabled": "false"})["enabled"] is False
+ assert tool.cast_params({"enabled": "1"})["enabled"] is True
+
+
+def test_cast_params_array_items() -> None:
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {
+ "nums": {"type": "array", "items": {"type": "integer"}},
+ },
+ }
+ )
+ result = tool.cast_params({"nums": ["1", "2", "3"]})
+ assert result["nums"] == [1, 2, 3]
+
+
+def test_cast_params_nested_object() -> None:
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {
+ "config": {
+ "type": "object",
+ "properties": {
+ "port": {"type": "integer"},
+ "debug": {"type": "boolean"},
+ },
+ },
+ },
+ }
+ )
+ result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
+ assert result["config"]["port"] == 8080
+ assert result["config"]["debug"] is True
+
+
+def test_cast_params_bool_not_cast_to_int() -> None:
+ """Booleans should not be silently cast to integers."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"count": {"type": "integer"}},
+ }
+ )
+ result = tool.cast_params({"count": True})
+ assert result["count"] is True
+ errors = tool.validate_params(result)
+ assert any("count should be integer" in e for e in errors)
+
+
+def test_cast_params_preserves_empty_string() -> None:
+ """Empty strings should be preserved for string type."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"name": {"type": "string"}},
+ }
+ )
+ result = tool.cast_params({"name": ""})
+ assert result["name"] == ""
+
+
+def test_cast_params_bool_string_false() -> None:
+ """Test that 'false', '0', 'no' strings convert to False."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"flag": {"type": "boolean"}},
+ }
+ )
+ assert tool.cast_params({"flag": "false"})["flag"] is False
+ assert tool.cast_params({"flag": "False"})["flag"] is False
+ assert tool.cast_params({"flag": "0"})["flag"] is False
+ assert tool.cast_params({"flag": "no"})["flag"] is False
+ assert tool.cast_params({"flag": "NO"})["flag"] is False
+
+
+def test_cast_params_bool_string_invalid() -> None:
+ """Invalid boolean strings should not be cast."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"flag": {"type": "boolean"}},
+ }
+ )
+ # Invalid strings should be preserved (validation will catch them)
+ result = tool.cast_params({"flag": "random"})
+ assert result["flag"] == "random"
+ result = tool.cast_params({"flag": "maybe"})
+ assert result["flag"] == "maybe"
+
+
+def test_cast_params_invalid_string_to_int() -> None:
+ """Invalid strings should not be cast to integer."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"count": {"type": "integer"}},
+ }
+ )
+ result = tool.cast_params({"count": "abc"})
+ assert result["count"] == "abc" # Original value preserved
+ result = tool.cast_params({"count": "12.5.7"})
+ assert result["count"] == "12.5.7"
+
+
+def test_cast_params_invalid_string_to_number() -> None:
+ """Invalid strings should not be cast to number."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"rate": {"type": "number"}},
+ }
+ )
+ result = tool.cast_params({"rate": "not_a_number"})
+ assert result["rate"] == "not_a_number"
+
+
+def test_validate_params_bool_not_accepted_as_number() -> None:
+ """Booleans should not pass number validation."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"rate": {"type": "number"}},
+ }
+ )
+ errors = tool.validate_params({"rate": False})
+ assert any("rate should be number" in e for e in errors)
+
+
+def test_cast_params_none_values() -> None:
+ """Test None handling for different types."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "count": {"type": "integer"},
+ "items": {"type": "array"},
+ "config": {"type": "object"},
+ },
+ }
+ )
+ result = tool.cast_params(
+ {
+ "name": None,
+ "count": None,
+ "items": None,
+ "config": None,
+ }
+ )
+ # None should be preserved for all types
+ assert result["name"] is None
+ assert result["count"] is None
+ assert result["items"] is None
+ assert result["config"] is None
+
+
+def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
+ """Single values should NOT be automatically wrapped into arrays."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"items": {"type": "array"}},
+ }
+ )
+ # Non-array values should be preserved (validation will catch them)
+ result = tool.cast_params({"items": 5})
+ assert result["items"] == 5 # Not wrapped to [5]
+ result = tool.cast_params({"items": "text"})
+ assert result["items"] == "text" # Not wrapped to ["text"]