Merge branch 'main' into pr-2304

This commit is contained in:
Xubin Ren
2026-03-21 04:14:40 +00:00
26 changed files with 2536 additions and 143 deletions

View File

@@ -2,7 +2,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
# Install Node.js 20 for the WhatsApp bridge # Install Node.js 20 for the WhatsApp bridge
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \ apt-get install -y --no-install-recommends curl ca-certificates gnupg git openssh-client && \
mkdir -p /etc/apt/keyrings && \ mkdir -p /etc/apt/keyrings && \
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \ curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \ echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
@@ -26,6 +26,8 @@ COPY bridge/ bridge/
RUN uv pip install --system --no-cache . RUN uv pip install --system --no-cache .
# Build the WhatsApp bridge # Build the WhatsApp bridge
RUN git config --global url."https://github.com/".insteadOf "ssh://git@github.com/"
WORKDIR /app/bridge WORKDIR /app/bridge
RUN npm install && npm run build RUN npm install && npm run build
WORKDIR /app WORKDIR /app

View File

@@ -191,9 +191,11 @@ nanobot channels login
nanobot onboard nanobot onboard
``` ```
Use `nanobot onboard --wizard` if you want the interactive setup wizard.
**2. Configure** (`~/.nanobot/config.json`) **2. Configure** (`~/.nanobot/config.json`)
Add or merge these **two parts** into your config (other options have defaults). Configure these **two parts** in your config (other options have defaults).
*Set your API key* (e.g. OpenRouter, recommended for global users): *Set your API key* (e.g. OpenRouter, recommended for global users):
```json ```json
@@ -809,6 +811,7 @@ Config file: `~/.nanobot/config.json`
<summary><b>OpenAI Codex (OAuth)</b></summary> <summary><b>OpenAI Codex (OAuth)</b></summary>
Codex uses OAuth instead of API keys. Requires a ChatGPT Plus or Pro account. Codex uses OAuth instead of API keys. Requires a ChatGPT Plus or Pro account.
No `providers.openaiCodex` block is needed in `config.json`; `nanobot provider login` stores the OAuth session outside config.
**1. Login:** **1. Login:**
```bash ```bash
@@ -841,6 +844,44 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
</details> </details>
<details>
<summary><b>GitHub Copilot (OAuth)</b></summary>
GitHub Copilot uses OAuth instead of API keys. Requires a [GitHub account with a plan](https://github.com/features/copilot/plans) configured.
No `providers.githubCopilot` block is needed in `config.json`; `nanobot provider login` stores the OAuth session outside config.
**1. Login:**
```bash
nanobot provider login github-copilot
```
**2. Set model** (merge into `~/.nanobot/config.json`):
```json
{
"agents": {
"defaults": {
"model": "github-copilot/gpt-4.1"
}
}
}
```
**3. Chat:**
```bash
nanobot agent -m "Hello!"
# Target a specific workspace/config locally
nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello!"
# One-off workspace override on top of that config
nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -m "Hello!"
```
> Docker users: use `docker run -it` for interactive OAuth login.
</details>
<details> <details>
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary> <summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
@@ -1161,6 +1202,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
| Option | Default | Description | | Option | Default | Description |
|--------|---------|-------------| |--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. | | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
| `tools.exec.enable` | `true` | When `false`, the shell `exec` tool is not registered at all. Use this to completely disable shell command execution. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). | | `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. | | `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |
@@ -1288,6 +1330,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
| Command | Description | | Command | Description |
|---------|-------------| |---------|-------------|
| `nanobot onboard` | Initialize config & workspace at `~/.nanobot/` | | `nanobot onboard` | Initialize config & workspace at `~/.nanobot/` |
| `nanobot onboard --wizard` | Launch the interactive onboarding wizard |
| `nanobot onboard -c <config> -w <workspace>` | Initialize or refresh a specific instance config and workspace | | `nanobot onboard -c <config> -w <workspace>` | Initialize or refresh a specific instance config and workspace |
| `nanobot agent -m "..."` | Chat with the agent | | `nanobot agent -m "..."` | Chat with the agent |
| `nanobot agent -w <workspace>` | Chat against a specific workspace | | `nanobot agent -w <workspace>` | Chat against a specific workspace |

View File

@@ -120,6 +120,7 @@ class AgentLoop:
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
for cls in (WriteFileTool, EditFileTool, ListDirTool): for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
if self.exec_config.enable:
self.tools.register(ExecTool( self.tools.register(ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),
timeout=self.exec_config.timeout, timeout=self.exec_config.timeout,

View File

@@ -32,7 +32,9 @@ class SpawnTool(Tool):
return ( return (
"Spawn a subagent to handle a task in the background. " "Spawn a subagent to handle a task in the background. "
"Use this for complex or time-consuming tasks that can run independently. " "Use this for complex or time-consuming tasks that can run independently. "
"The subagent will complete the task and report back when done." "The subagent will complete the task and report back when done. "
"For deliverables or existing projects, inspect the workspace first "
"and use a dedicated subdirectory when helpful."
) )
@property @property

View File

@@ -80,6 +80,21 @@ class EmailChannel(BaseChannel):
"Nov", "Nov",
"Dec", "Dec",
) )
_IMAP_RECONNECT_MARKERS = (
"disconnected for inactivity",
"eof occurred in violation of protocol",
"socket error",
"connection reset",
"broken pipe",
"bye",
)
_IMAP_MISSING_MAILBOX_MARKERS = (
"mailbox doesn't exist",
"select failed",
"no such mailbox",
"can't open mailbox",
"does not exist",
)
@classmethod @classmethod
def default_config(cls) -> dict[str, Any]: def default_config(cls) -> dict[str, Any]:
@@ -267,8 +282,37 @@ class EmailChannel(BaseChannel):
dedupe: bool, dedupe: bool,
limit: int, limit: int,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Fetch messages by arbitrary IMAP search criteria."""
messages: list[dict[str, Any]] = [] messages: list[dict[str, Any]] = []
cycle_uids: set[str] = set()
for attempt in range(2):
try:
self._fetch_messages_once(
search_criteria,
mark_seen,
dedupe,
limit,
messages,
cycle_uids,
)
return messages
except Exception as exc:
if attempt == 1 or not self._is_stale_imap_error(exc):
raise
logger.warning("Email IMAP connection went stale, retrying once: {}", exc)
return messages
def _fetch_messages_once(
self,
search_criteria: tuple[str, ...],
mark_seen: bool,
dedupe: bool,
limit: int,
messages: list[dict[str, Any]],
cycle_uids: set[str],
) -> None:
"""Fetch messages by arbitrary IMAP search criteria."""
mailbox = self.config.imap_mailbox or "INBOX" mailbox = self.config.imap_mailbox or "INBOX"
if self.config.imap_use_ssl: if self.config.imap_use_ssl:
@@ -278,8 +322,15 @@ class EmailChannel(BaseChannel):
try: try:
client.login(self.config.imap_username, self.config.imap_password) client.login(self.config.imap_username, self.config.imap_password)
try:
status, _ = client.select(mailbox) status, _ = client.select(mailbox)
except Exception as exc:
if self._is_missing_mailbox_error(exc):
logger.warning("Email mailbox unavailable, skipping poll for {}: {}", mailbox, exc)
return messages
raise
if status != "OK": if status != "OK":
logger.warning("Email mailbox select returned {}, skipping poll for {}", status, mailbox)
return messages return messages
status, data = client.search(None, *search_criteria) status, data = client.search(None, *search_criteria)
@@ -299,6 +350,8 @@ class EmailChannel(BaseChannel):
continue continue
uid = self._extract_uid(fetched) uid = self._extract_uid(fetched)
if uid and uid in cycle_uids:
continue
if dedupe and uid and uid in self._processed_uids: if dedupe and uid and uid in self._processed_uids:
continue continue
@@ -341,6 +394,8 @@ class EmailChannel(BaseChannel):
} }
) )
if uid:
cycle_uids.add(uid)
if dedupe and uid: if dedupe and uid:
self._processed_uids.add(uid) self._processed_uids.add(uid)
# mark_seen is the primary dedup; this set is a safety net # mark_seen is the primary dedup; this set is a safety net
@@ -356,7 +411,15 @@ class EmailChannel(BaseChannel):
except Exception: except Exception:
pass pass
return messages @classmethod
def _is_stale_imap_error(cls, exc: Exception) -> bool:
message = str(exc).lower()
return any(marker in message for marker in cls._IMAP_RECONNECT_MARKERS)
@classmethod
def _is_missing_mailbox_error(cls, exc: Exception) -> bool:
message = str(exc).lower()
return any(marker in message for marker in cls._IMAP_MISSING_MAILBOX_MARKERS)
@classmethod @classmethod
def _format_imap_date(cls, value: date) -> str: def _format_imap_date(cls, value: date) -> str:

View File

@@ -191,6 +191,10 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
texts.append(el.get("text", "")) texts.append(el.get("text", ""))
elif tag == "at": elif tag == "at":
texts.append(f"@{el.get('user_name', 'user')}") texts.append(f"@{el.get('user_name', 'user')}")
elif tag == "code_block":
lang = el.get("language", "")
code_text = el.get("text", "")
texts.append(f"\n```{lang}\n{code_text}\n```\n")
elif tag == "img" and (key := el.get("image_key")): elif tag == "img" and (key := el.get("image_key")):
images.append(key) images.append(key)
return (" ".join(texts).strip() or None), images return (" ".join(texts).strip() or None), images

View File

@@ -11,6 +11,7 @@ from typing import Any, Literal
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
from telegram import BotCommand, ReplyParameters, Update from telegram import BotCommand, ReplyParameters, Update
from telegram.error import TimedOut
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
@@ -151,6 +152,10 @@ def _markdown_to_telegram_html(text: str) -> str:
return text return text
_SEND_MAX_RETRIES = 3
_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
class TelegramConfig(Base): class TelegramConfig(Base):
"""Telegram channel configuration.""" """Telegram channel configuration."""
@@ -160,6 +165,8 @@ class TelegramConfig(Base):
proxy: str | None = None proxy: str | None = None
reply_to_message: bool = False reply_to_message: bool = False
group_policy: Literal["open", "mention"] = "mention" group_policy: Literal["open", "mention"] = "mention"
connection_pool_size: int = 32
pool_timeout: float = 5.0
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
@@ -226,15 +233,29 @@ class TelegramChannel(BaseChannel):
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs proxy = self.config.proxy or None
req = HTTPXRequest(
connection_pool_size=16, # Separate pools so long-polling (getUpdates) never starves outbound sends.
pool_timeout=5.0, api_request = HTTPXRequest(
connection_pool_size=self.config.connection_pool_size,
pool_timeout=self.config.pool_timeout,
connect_timeout=30.0, connect_timeout=30.0,
read_timeout=30.0, read_timeout=30.0,
proxy=self.config.proxy if self.config.proxy else None, proxy=proxy,
)
poll_request = HTTPXRequest(
connection_pool_size=4,
pool_timeout=self.config.pool_timeout,
connect_timeout=30.0,
read_timeout=30.0,
proxy=proxy,
)
builder = (
Application.builder()
.token(self.config.token)
.request(api_request)
.get_updates_request(poll_request)
) )
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
@@ -365,7 +386,8 @@ class TelegramChannel(BaseChannel):
ok, error = validate_url_target(media_path) ok, error = validate_url_target(media_path)
if not ok: if not ok:
raise ValueError(f"unsafe media URL: {error}") raise ValueError(f"unsafe media URL: {error}")
await sender( await self._call_with_retry(
sender,
chat_id=chat_id, chat_id=chat_id,
**{param: media_path}, **{param: media_path},
reply_parameters=reply_params, reply_parameters=reply_params,
@@ -401,6 +423,21 @@ class TelegramChannel(BaseChannel):
else: else:
await self._send_text(chat_id, chunk, reply_params, thread_kwargs) await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _call_with_retry(self, fn, *args, **kwargs):
"""Call an async Telegram API function with retry on pool/network timeout."""
for attempt in range(1, _SEND_MAX_RETRIES + 1):
try:
return await fn(*args, **kwargs)
except TimedOut:
if attempt == _SEND_MAX_RETRIES:
raise
delay = _SEND_RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.warning(
"Telegram timeout (attempt {}/{}), retrying in {:.1f}s",
attempt, _SEND_MAX_RETRIES, delay,
)
await asyncio.sleep(delay)
async def _send_text( async def _send_text(
self, self,
chat_id: int, chat_id: int,
@@ -411,7 +448,8 @@ class TelegramChannel(BaseChannel):
"""Send a plain text message with HTML fallback.""" """Send a plain text message with HTML fallback."""
try: try:
html = _markdown_to_telegram_html(text) html = _markdown_to_telegram_html(text)
await self._app.bot.send_message( await self._call_with_retry(
self._app.bot.send_message,
chat_id=chat_id, text=html, parse_mode="HTML", chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params, reply_parameters=reply_params,
**(thread_kwargs or {}), **(thread_kwargs or {}),
@@ -419,7 +457,8 @@ class TelegramChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
await self._app.bot.send_message( await self._call_with_retry(
self._app.bot.send_message,
chat_id=chat_id, chat_id=chat_id,
text=text, text=text,
reply_parameters=reply_params, reply_parameters=reply_params,

View File

@@ -21,12 +21,11 @@ if sys.platform == "win32":
pass pass
import typer import typer
from prompt_toolkit import print_formatted_text from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit import PromptSession from prompt_toolkit.application import run_in_terminal
from prompt_toolkit.formatted_text import ANSI, HTML from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
@@ -39,6 +38,7 @@ from nanobot.utils.helpers import sync_workspace_templates
app = typer.Typer( app = typer.Typer(
name="nanobot", name="nanobot",
context_settings={"help_option_names": ["-h", "--help"]},
help=f"{__logo__} nanobot - Personal AI Assistant", help=f"{__logo__} nanobot - Personal AI Assistant",
no_args_is_help=True, no_args_is_help=True,
) )
@@ -265,6 +265,7 @@ def main(
def onboard( def onboard(
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
wizard: bool = typer.Option(False, "--wizard", help="Use interactive wizard"),
): ):
"""Initialize nanobot configuration and workspace.""" """Initialize nanobot configuration and workspace."""
from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path
@@ -284,6 +285,9 @@ def onboard(
# Create or update config # Create or update config
if config_path.exists(): if config_path.exists():
if wizard:
config = _apply_workspace_override(load_config(config_path))
else:
console.print(f"[yellow]Config already exists at {config_path}[/yellow]") console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields") console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
@@ -297,26 +301,50 @@ def onboard(
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
else: else:
config = _apply_workspace_override(Config()) config = _apply_workspace_override(Config())
# In wizard mode, don't save yet - the wizard will handle saving if should_save=True
if not wizard:
save_config(config, config_path) save_config(config, config_path)
console.print(f"[green]✓[/green] Created config at {config_path}") console.print(f"[green]✓[/green] Created config at {config_path}")
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
# Run interactive wizard if enabled
if wizard:
from nanobot.cli.onboard_wizard import run_onboard
try:
result = run_onboard(initial_config=config)
if not result.should_save:
console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]")
return
config = result.config
save_config(config, config_path)
console.print(f"[green]✓[/green] Config saved at {config_path}")
except Exception as e:
console.print(f"[red]✗[/red] Error during configuration: {e}")
console.print("[yellow]Please run 'nanobot onboard' again to complete setup.[/yellow]")
raise typer.Exit(1)
_onboard_plugins(config_path) _onboard_plugins(config_path)
# Create workspace, preferring the configured workspace path. # Create workspace, preferring the configured workspace path.
workspace = get_workspace_path(config.workspace_path) workspace_path = get_workspace_path(config.workspace_path)
if not workspace.exists(): if not workspace_path.exists():
workspace.mkdir(parents=True, exist_ok=True) workspace_path.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}") console.print(f"[green]✓[/green] Created workspace at {workspace_path}")
sync_workspace_templates(workspace) sync_workspace_templates(workspace_path)
agent_cmd = 'nanobot agent -m "Hello!"' agent_cmd = 'nanobot agent -m "Hello!"'
gateway_cmd = "nanobot gateway"
if config: if config:
agent_cmd += f" --config {config_path}" agent_cmd += f" --config {config_path}"
gateway_cmd += f" --config {config_path}"
console.print(f"\n{__logo__} nanobot is ready!") console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:") console.print("\nNext steps:")
if wizard:
console.print(f" 1. Chat: [cyan]{agent_cmd}[/cyan]")
console.print(f" 2. Start gateway: [cyan]{gateway_cmd}[/cyan]")
else:
console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]") console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
console.print(" Get one at: https://openrouter.ai/keys") console.print(" Get one at: https://openrouter.ai/keys")
console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]") console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
@@ -363,9 +391,9 @@ def _onboard_plugins(config_path: Path) -> None:
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import GenerationSettings from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)
@@ -434,21 +462,30 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
console.print(f"[dim]Using config: {config_path}[/dim]") console.print(f"[dim]Using config: {config_path}[/dim]")
loaded = load_config(config_path) loaded = load_config(config_path)
_warn_deprecated_config_keys(config_path)
if workspace: if workspace:
loaded.agents.defaults.workspace = workspace loaded.agents.defaults.workspace = workspace
return loaded return loaded
def _print_deprecated_memory_window_notice(config: Config) -> None: def _warn_deprecated_config_keys(config_path: Path | None) -> None:
"""Warn when running with old memoryWindow-only config.""" """Hint users to remove obsolete keys from their config file."""
if config.agents.defaults.should_warn_deprecated_memory_window: import json
from nanobot.config.loader import get_config_path
path = config_path or get_config_path()
try:
raw = json.loads(path.read_text(encoding="utf-8"))
except Exception:
return
if "memoryWindow" in raw.get("agents", {}).get("defaults", {}):
console.print( console.print(
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without " "[dim]Hint: `memoryWindow` in your config is no longer used "
"`contextWindowTokens`. `memoryWindow` is ignored; run " "and can be safely removed.[/dim]"
"[cyan]nanobot onboard[/cyan] to refresh your config template."
) )
# ============================================================================ # ============================================================================
# Gateway / Server # Gateway / Server
# ============================================================================ # ============================================================================
@@ -476,7 +513,6 @@ def gateway(
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
config = _load_runtime_config(config, workspace) config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
port = port if port is not None else config.gateway.port port = port if port is not None else config.gateway.port
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...") console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
@@ -667,7 +703,6 @@ def agent(
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
config = _load_runtime_config(config, workspace) config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()

231
nanobot/cli/model_info.py Normal file
View File

@@ -0,0 +1,231 @@
"""Model information helpers for the onboard wizard.
Provides model context window lookup and autocomplete suggestions using litellm.
"""
from __future__ import annotations
from functools import lru_cache
from typing import Any
def _litellm():
"""Lazy accessor for litellm (heavy import deferred until actually needed)."""
import litellm as _ll
return _ll
@lru_cache(maxsize=1)
def _get_model_cost_map() -> dict[str, Any]:
"""Get litellm's model cost map (cached)."""
return getattr(_litellm(), "model_cost", {})
@lru_cache(maxsize=1)
def get_all_models() -> list[str]:
"""Get all known model names from litellm.
"""
models = set()
# From model_cost (has pricing info)
cost_map = _get_model_cost_map()
for k in cost_map.keys():
if k != "sample_spec":
models.add(k)
# From models_by_provider (more complete provider coverage)
for provider_models in getattr(_litellm(), "models_by_provider", {}).values():
if isinstance(provider_models, (set, list)):
models.update(provider_models)
return sorted(models)
def _normalize_model_name(model: str) -> str:
"""Normalize model name for comparison."""
return model.lower().replace("-", "_").replace(".", "")
def find_model_info(model_name: str) -> dict[str, Any] | None:
"""Find model info with fuzzy matching.
Args:
model_name: Model name in any common format
Returns:
Model info dict or None if not found
"""
cost_map = _get_model_cost_map()
if not cost_map:
return None
# Direct match
if model_name in cost_map:
return cost_map[model_name]
# Extract base name (without provider prefix)
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
base_normalized = _normalize_model_name(base_name)
candidates = []
for key, info in cost_map.items():
if key == "sample_spec":
continue
key_base = key.split("/")[-1] if "/" in key else key
key_base_normalized = _normalize_model_name(key_base)
# Score the match
score = 0
# Exact base name match (highest priority)
if base_normalized == key_base_normalized:
score = 100
# Base name contains model
elif base_normalized in key_base_normalized:
score = 80
# Model contains base name
elif key_base_normalized in base_normalized:
score = 70
# Partial match
elif base_normalized[:10] in key_base_normalized:
score = 50
if score > 0:
# Prefer models with max_input_tokens
if info.get("max_input_tokens"):
score += 10
candidates.append((score, key, info))
if not candidates:
return None
# Return the best match
candidates.sort(key=lambda x: (-x[0], x[1]))
return candidates[0][2]
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
"""Get the maximum input context tokens for a model.
Args:
model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o")
provider: Provider name for informational purposes (not yet used for filtering)
Returns:
Maximum input tokens, or None if unknown
Note:
The provider parameter is currently informational only. Future versions may
use it to prefer provider-specific model variants in the lookup.
"""
# First try fuzzy search in model_cost (has more accurate max_input_tokens)
info = find_model_info(model)
if info:
# Prefer max_input_tokens (this is what we want for context window)
max_input = info.get("max_input_tokens")
if max_input and isinstance(max_input, int):
return max_input
# Fall back to litellm's get_max_tokens (returns max_output_tokens typically)
try:
result = _litellm().get_max_tokens(model)
if result and result > 0:
return result
except (KeyError, ValueError, AttributeError):
# Model not found in litellm's database or invalid response
pass
# Last resort: use max_tokens from model_cost
if info:
max_tokens = info.get("max_tokens")
if max_tokens and isinstance(max_tokens, int):
return max_tokens
return None
@lru_cache(maxsize=1)
def _get_provider_keywords() -> dict[str, list[str]]:
"""Build provider keywords mapping from nanobot's provider registry.
Returns:
Dict mapping provider name to list of keywords for model filtering.
"""
try:
from nanobot.providers.registry import PROVIDERS
mapping = {}
for spec in PROVIDERS:
if spec.keywords:
mapping[spec.name] = list(spec.keywords)
return mapping
except ImportError:
return {}
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
"""Get autocomplete suggestions for model names.
Args:
partial: Partial model name typed by user
provider: Provider name for filtering (e.g., "openrouter", "minimax")
limit: Maximum number of suggestions to return
Returns:
List of matching model names
"""
all_models = get_all_models()
if not all_models:
return []
partial_lower = partial.lower()
partial_normalized = _normalize_model_name(partial)
# Get provider keywords from registry
provider_keywords = _get_provider_keywords()
# Filter by provider if specified
allowed_keywords = None
if provider and provider != "auto":
allowed_keywords = provider_keywords.get(provider.lower())
matches = []
for model in all_models:
model_lower = model.lower()
# Apply provider filter
if allowed_keywords:
if not any(kw in model_lower for kw in allowed_keywords):
continue
# Match against partial input
if not partial:
matches.append(model)
continue
if partial_lower in model_lower:
# Score by position of match (earlier = better)
pos = model_lower.find(partial_lower)
score = 100 - pos
matches.append((score, model))
elif partial_normalized in _normalize_model_name(model):
score = 50
matches.append((score, model))
# Sort by score if we have scored matches
if matches and isinstance(matches[0], tuple):
matches.sort(key=lambda x: (-x[0], x[1]))
matches = [m[1] for m in matches]
else:
matches.sort()
return matches[:limit]
def format_token_count(tokens: int) -> str:
"""Format token count for display (e.g., 200000 -> '200,000')."""
return f"{tokens:,}"

File diff suppressed because it is too large Load Diff

View File

@@ -3,8 +3,10 @@
import json import json
from pathlib import Path from pathlib import Path
from nanobot.config.schema import Config import pydantic
from loguru import logger
from nanobot.config.schema import Config
# Global variable to store current config path (for multi-instance support) # Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None _current_config_path: Path | None = None
@@ -41,9 +43,9 @@ def load_config(config_path: Path | None = None) -> Config:
data = json.load(f) data = json.load(f)
data = _migrate_config(data) data = _migrate_config(data)
return Config.model_validate(data) return Config.model_validate(data)
except (json.JSONDecodeError, ValueError) as e: except (json.JSONDecodeError, ValueError, pydantic.ValidationError) as e:
print(f"Warning: Failed to load config from {path}: {e}") logger.warning(f"Failed to load config from {path}: {e}")
print("Using default configuration.") logger.warning("Using default configuration.")
return Config() return Config()
@@ -59,7 +61,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
path = config_path or get_config_path() path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
data = config.model_dump(by_alias=True) data = config.model_dump(mode="json", by_alias=True)
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)

View File

@@ -38,14 +38,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
temperature: float = 0.1 temperature: float = 0.1
max_tool_iterations: int = 40 max_tool_iterations: int = 40
# Deprecated compatibility field: accepted from old configs but ignored at runtime. reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
memory_window: int | None = Field(default=None, exclude=True)
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
@property
def should_warn_deprecated_memory_window(self) -> bool:
"""Return True when old memoryWindow is present without contextWindowTokens."""
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
class AgentsConfig(Base): class AgentsConfig(Base):
@@ -85,8 +78,8 @@ class ProvidersConfig(Base):
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international) byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # Github Copilot (OAuth)
class HeartbeatConfig(Base): class HeartbeatConfig(Base):
@@ -125,10 +118,10 @@ class WebToolsConfig(Base):
class ExecToolConfig(Base): class ExecToolConfig(Base):
"""Shell exec tool configuration.""" """Shell exec tool configuration."""
enable: bool = True
timeout: int = 60 timeout: int = 60
path_append: str = "" path_append: str = ""
class MCPServerConfig(Base): class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP).""" """MCP server connection configuration (stdio or HTTP)."""

View File

@@ -10,7 +10,7 @@ from typing import Any, Callable, Coroutine
from loguru import logger from loguru import logger
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule, CronStore from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
def _now_ms() -> int: def _now_ms() -> int:
@@ -63,10 +63,12 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
class CronService: class CronService:
"""Service for managing and executing scheduled jobs.""" """Service for managing and executing scheduled jobs."""
_MAX_RUN_HISTORY = 20
def __init__( def __init__(
self, self,
store_path: Path, store_path: Path,
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
): ):
self.store_path = store_path self.store_path = store_path
self.on_job = on_job self.on_job = on_job
@@ -113,6 +115,15 @@ class CronService:
last_run_at_ms=j.get("state", {}).get("lastRunAtMs"), last_run_at_ms=j.get("state", {}).get("lastRunAtMs"),
last_status=j.get("state", {}).get("lastStatus"), last_status=j.get("state", {}).get("lastStatus"),
last_error=j.get("state", {}).get("lastError"), last_error=j.get("state", {}).get("lastError"),
run_history=[
CronRunRecord(
run_at_ms=r["runAtMs"],
status=r["status"],
duration_ms=r.get("durationMs", 0),
error=r.get("error"),
)
for r in j.get("state", {}).get("runHistory", [])
],
), ),
created_at_ms=j.get("createdAtMs", 0), created_at_ms=j.get("createdAtMs", 0),
updated_at_ms=j.get("updatedAtMs", 0), updated_at_ms=j.get("updatedAtMs", 0),
@@ -160,6 +171,15 @@ class CronService:
"lastRunAtMs": j.state.last_run_at_ms, "lastRunAtMs": j.state.last_run_at_ms,
"lastStatus": j.state.last_status, "lastStatus": j.state.last_status,
"lastError": j.state.last_error, "lastError": j.state.last_error,
"runHistory": [
{
"runAtMs": r.run_at_ms,
"status": r.status,
"durationMs": r.duration_ms,
"error": r.error,
}
for r in j.state.run_history
],
}, },
"createdAtMs": j.created_at_ms, "createdAtMs": j.created_at_ms,
"updatedAtMs": j.updated_at_ms, "updatedAtMs": j.updated_at_ms,
@@ -248,9 +268,8 @@ class CronService:
logger.info("Cron: executing job '{}' ({})", job.name, job.id) logger.info("Cron: executing job '{}' ({})", job.name, job.id)
try: try:
response = None
if self.on_job: if self.on_job:
response = await self.on_job(job) await self.on_job(job)
job.state.last_status = "ok" job.state.last_status = "ok"
job.state.last_error = None job.state.last_error = None
@@ -261,8 +280,17 @@ class CronService:
job.state.last_error = str(e) job.state.last_error = str(e)
logger.error("Cron: job '{}' failed: {}", job.name, e) logger.error("Cron: job '{}' failed: {}", job.name, e)
end_ms = _now_ms()
job.state.last_run_at_ms = start_ms job.state.last_run_at_ms = start_ms
job.updated_at_ms = _now_ms() job.updated_at_ms = end_ms
job.state.run_history.append(CronRunRecord(
run_at_ms=start_ms,
status=job.state.last_status,
duration_ms=end_ms - start_ms,
error=job.state.last_error,
))
job.state.run_history = job.state.run_history[-self._MAX_RUN_HISTORY:]
# Handle one-shot jobs # Handle one-shot jobs
if job.schedule.kind == "at": if job.schedule.kind == "at":
@@ -366,6 +394,11 @@ class CronService:
return True return True
return False return False
def get_job(self, job_id: str) -> CronJob | None:
"""Get a job by ID."""
store = self._load_store()
return next((j for j in store.jobs if j.id == job_id), None)
def status(self) -> dict: def status(self) -> dict:
"""Get service status.""" """Get service status."""
store = self._load_store() store = self._load_store()

View File

@@ -29,6 +29,15 @@ class CronPayload:
to: str | None = None # e.g. phone number to: str | None = None # e.g. phone number
@dataclass
class CronRunRecord:
"""A single execution record for a cron job."""
run_at_ms: int
status: Literal["ok", "error", "skipped"]
duration_ms: int = 0
error: str | None = None
@dataclass @dataclass
class CronJobState: class CronJobState:
"""Runtime state of a job.""" """Runtime state of a job."""
@@ -36,6 +45,7 @@ class CronJobState:
last_run_at_ms: int | None = None last_run_at_ms: int | None = None
last_status: Literal["ok", "error", "skipped"] | None = None last_status: Literal["ok", "error", "skipped"] | None = None
last_error: str | None = None last_error: str | None = None
run_history: list[CronRunRecord] = field(default_factory=list)
@dataclass @dataclass

View File

@@ -51,6 +51,12 @@ class CustomProvider(LLMProvider):
try: try:
return self._parse(await self._client.chat.completions.create(**kwargs)) return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e: except Exception as e:
# JSONDecodeError.doc / APIError.response.text may carry the raw body
# (e.g. "unsupported model: xxx") which is far more useful than the
# generic "Expecting value …" message. Truncate to avoid huge HTML pages.
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
if body and body.strip():
return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error")
return LLMResponse(content=f"Error: {e}", finish_reason="error") return LLMResponse(content=f"Error: {e}", finish_reason="error")
def _parse(self, response: Any) -> LLMResponse: def _parse(self, response: Any) -> LLMResponse:

View File

@@ -42,6 +42,7 @@ dependencies = [
"qq-botpy>=1.2.0,<2.0.0", "qq-botpy>=1.2.0,<2.0.0",
"python-socks[asyncio]>=2.8.0,<3.0.0", "python-socks[asyncio]>=2.8.0,<3.0.0",
"prompt-toolkit>=3.0.50,<4.0.0", "prompt-toolkit>=3.0.50,<4.0.0",
"questionary>=2.0.0,<3.0.0",
"mcp>=1.26.0,<2.0.0", "mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0", "json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0", "chardet>=3.0.2,<6.0.0",

View File

@@ -1,6 +1,5 @@
import json import json
import re import re
import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -13,19 +12,18 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model from nanobot.providers.registry import find_by_model
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
runner = CliRunner() runner = CliRunner()
class _StopGateway(RuntimeError): class _StopGatewayError(RuntimeError):
pass pass
import shutil
import pytest
@pytest.fixture @pytest.fixture
def mock_paths(): def mock_paths():
"""Mock config/workspace paths for test isolation.""" """Mock config/workspace paths for test isolation."""
@@ -117,6 +115,12 @@ def test_onboard_existing_workspace_safe_create(mock_paths):
assert (workspace_dir / "AGENTS.md").exists() assert (workspace_dir / "AGENTS.md").exists()
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
def test_onboard_help_shows_workspace_and_config_options(): def test_onboard_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["onboard", "--help"]) result = runner.invoke(app, ["onboard", "--help"])
@@ -126,9 +130,28 @@ def test_onboard_help_shows_workspace_and_config_options():
assert "-w" in stripped_output assert "-w" in stripped_output
assert "--config" in stripped_output assert "--config" in stripped_output
assert "-c" in stripped_output assert "-c" in stripped_output
assert "--wizard" in stripped_output
assert "--dir" not in stripped_output assert "--dir" not in stripped_output
def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
config_file, workspace_dir, _ = mock_paths
from nanobot.cli.onboard_wizard import OnboardResult
monkeypatch.setattr(
"nanobot.cli.onboard_wizard.run_onboard",
lambda initial_config: OnboardResult(config=initial_config, should_save=False),
)
result = runner.invoke(app, ["onboard", "--wizard"])
assert result.exit_code == 0
assert "No changes were saved" in result.stdout
assert not config_file.exists()
assert not workspace_dir.exists()
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch): def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json" config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace" workspace_path = tmp_path / "workspace"
@@ -151,6 +174,31 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch)
assert f"--config {resolved_config}" in compact_output assert f"--config {resolved_config}" in compact_output
def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace"
from nanobot.cli.onboard_wizard import OnboardResult
monkeypatch.setattr(
"nanobot.cli.onboard_wizard.run_onboard",
lambda initial_config: OnboardResult(config=initial_config, should_save=True),
)
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(
app,
["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)],
)
assert result.exit_code == 0
stripped_output = _strip_ansi(result.stdout)
compact_output = stripped_output.replace("\n", "")
resolved_config = str(config_path.resolve())
assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output
assert f"nanobot gateway --config {resolved_config}" in compact_output
def test_config_matches_github_copilot_codex_with_hyphen_prefix(): def test_config_matches_github_copilot_codex_with_hyphen_prefix():
config = Config() config = Config()
config.agents.defaults.model = "github-copilot/gpt-5.3-codex" config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
@@ -165,6 +213,15 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex" assert config.get_provider_name() == "openai_codex"
def test_config_dump_excludes_oauth_provider_blocks():
config = Config()
providers = config.model_dump(by_alias=True)["providers"]
assert "openaiCodex" not in providers
assert "githubCopilot" not in providers
def test_config_matches_explicit_ollama_prefix_without_api_key(): def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config() config = Config()
config.agents.defaults.model = "ollama/llama3.2" config.agents.defaults.model = "ollama/llama3.2"
@@ -404,14 +461,15 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime): def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
mock_agent_runtime["config"].agents.defaults.memory_window = 100 config_file = tmp_path / "config.json"
config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))
result = runner.invoke(app, ["agent", "-m", "hello"]) result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0 assert result.exit_code == 0
assert "memoryWindow" in result.stdout assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout assert "no longer used" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
@@ -434,12 +492,12 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
) )
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.cli.commands._make_provider", "nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
) )
result = runner.invoke(app, ["gateway", "--config", str(config_file)]) result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway) assert isinstance(result.exception, _StopGatewayError)
assert seen["config_path"] == config_file.resolve() assert seen["config_path"] == config_file.resolve()
assert seen["workspace"] == Path(config.agents.defaults.workspace) assert seen["workspace"] == Path(config.agents.defaults.workspace)
@@ -462,7 +520,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
) )
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.cli.commands._make_provider", "nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
) )
result = runner.invoke( result = runner.invoke(
@@ -470,33 +528,11 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
["gateway", "--config", str(config_file), "--workspace", str(override)], ["gateway", "--config", str(config_file), "--workspace", str(override)],
) )
assert isinstance(result.exception, _StopGateway) assert isinstance(result.exception, _StopGatewayError)
assert seen["workspace"] == override assert seen["workspace"] == override
assert config.workspace_path == override assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.memory_window = 100
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json" config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True) config_file.parent.mkdir(parents=True)
@@ -517,13 +553,13 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
class _StopCron: class _StopCron:
def __init__(self, store_path: Path) -> None: def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path seen["cron_store"] = store_path
raise _StopGateway("stop") raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)]) result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway) assert isinstance(result.exception, _StopGatewayError)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
@@ -540,12 +576,12 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.cli.commands._make_provider", "nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
) )
result = runner.invoke(app, ["gateway", "--config", str(config_file)]) result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway) assert isinstance(result.exception, _StopGatewayError)
assert "port 18791" in result.stdout assert "port 18791" in result.stdout
@@ -562,10 +598,10 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.cli.commands._make_provider", "nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
) )
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
assert isinstance(result.exception, _StopGateway) assert isinstance(result.exception, _StopGatewayError)
assert "port 18792" in result.stdout assert "port 18792" in result.stdout

View File

@@ -1,15 +1,9 @@
import json import json
from types import SimpleNamespace
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_ignores_legacy_memory_window(tmp_path) -> None:
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json" config_path = tmp_path / "config.json"
config_path.write_text( config_path.write_text(
json.dumps( json.dumps(
@@ -29,7 +23,7 @@ def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path
assert config.agents.defaults.max_tokens == 1234 assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536 assert config.agents.defaults.context_window_tokens == 65_536
assert config.agents.defaults.should_warn_deprecated_memory_window is True assert not hasattr(config.agents.defaults, "memory_window")
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None: def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
@@ -58,7 +52,7 @@ def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path
assert "memoryWindow" not in defaults assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None: def test_onboard_does_not_crash_with_legacy_memory_window(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json" config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace" workspace = tmp_path / "workspace"
config_path.write_text( config_path.write_text(
@@ -78,18 +72,17 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n") result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0 assert result.exit_code == 0
assert "contextWindowTokens" in result.stdout
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 3333
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None: def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
from types import SimpleNamespace
config_path = tmp_path / "config.json" config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace" workspace = tmp_path / "workspace"
config_path.write_text( config_path.write_text(
@@ -125,6 +118,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
}, },
) )
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n") result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0 assert result.exit_code == 0

View File

@@ -182,7 +182,7 @@ class TestConsolidationTriggerConditions:
"""Test consolidation trigger conditions and logic.""" """Test consolidation trigger conditions and logic."""
def test_consolidation_needed_when_messages_exceed_window(self): def test_consolidation_needed_when_messages_exceed_window(self):
"""Test consolidation logic: should trigger when messages > memory_window.""" """Test consolidation logic: should trigger when messages exceed the window."""
session = create_session_with_messages("test:trigger", 60) session = create_session_with_messages("test:trigger", 60)
total_messages = len(session.messages) total_messages = len(session.messages)

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import json
import pytest import pytest
@@ -32,6 +33,87 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.state.next_run_at_ms is not None assert job.state.next_run_at_ms is not None
@pytest.mark.asyncio
async def test_execute_job_records_run_history(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="hist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert loaded is not None
assert len(loaded.state.run_history) == 1
rec = loaded.state.run_history[0]
assert rec.status == "ok"
assert rec.duration_ms >= 0
assert rec.error is None
@pytest.mark.asyncio
async def test_run_history_records_errors(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
async def fail(_):
raise RuntimeError("boom")
service = CronService(store_path, on_job=fail)
job = service.add_job(
name="fail",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert len(loaded.state.run_history) == 1
assert loaded.state.run_history[0].status == "error"
assert loaded.state.run_history[0].error == "boom"
@pytest.mark.asyncio
async def test_run_history_trimmed_to_max(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="trim",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
for _ in range(25):
await service.run_job(job.id)
loaded = service.get_job(job.id)
assert len(loaded.state.run_history) == CronService._MAX_RUN_HISTORY
@pytest.mark.asyncio
async def test_run_history_persisted_to_disk(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="persist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
raw = json.loads(store_path.read_text())
history = raw["jobs"][0]["state"]["runHistory"]
assert len(history) == 1
assert history[0]["status"] == "ok"
assert "runAtMs" in history[0]
assert "durationMs" in history[0]
fresh = CronService(store_path)
loaded = fresh.get_job(job.id)
assert len(loaded.state.run_history) == 1
assert loaded.state.run_history[0].status == "ok"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None: async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json" store_path = tmp_path / "cron" / "jobs.json"

View File

@@ -1,5 +1,6 @@
from email.message import EmailMessage from email.message import EmailMessage
from datetime import date from datetime import date
import imaplib
import pytest import pytest
@@ -82,6 +83,120 @@ def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
assert items_again == [] assert items_again == []
def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
raw = _make_raw_email(subject="Invoice", body="Please pay")
fail_once = {"pending": True}
class FlakyIMAP:
def __init__(self) -> None:
self.store_calls: list[tuple[bytes, str, str]] = []
self.search_calls = 0
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"1"]
def search(self, *_args):
self.search_calls += 1
if fail_once["pending"]:
fail_once["pending"] = False
raise imaplib.IMAP4.abort("socket error")
return "OK", [b"1"]
def fetch(self, _imap_id: bytes, _parts: str):
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
def store(self, imap_id: bytes, op: str, flags: str):
self.store_calls.append((imap_id, op, flags))
return "OK", [b""]
def logout(self):
return "BYE", [b""]
fake_instances: list[FlakyIMAP] = []
def _factory(_host: str, _port: int):
instance = FlakyIMAP()
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)
channel = EmailChannel(_make_config(), MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(fake_instances) == 2
assert fake_instances[0].search_calls == 1
assert fake_instances[1].search_calls == 1
def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
raw_first = _make_raw_email(subject="First", body="First body")
raw_second = _make_raw_email(subject="Second", body="Second body")
mailbox_state = {
b"1": {"uid": b"123", "raw": raw_first, "seen": False},
b"2": {"uid": b"124", "raw": raw_second, "seen": False},
}
fail_once = {"pending": True}
class FlakyIMAP:
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
return "OK", [b"2"]
def search(self, *_args):
unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
return "OK", [b" ".join(unseen_ids)]
def fetch(self, imap_id: bytes, _parts: str):
if imap_id == b"2" and fail_once["pending"]:
fail_once["pending"] = False
raise imaplib.IMAP4.abort("socket error")
item = mailbox_state[imap_id]
header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
return "OK", [(header, item["raw"]), b")"]
def store(self, imap_id: bytes, _op: str, _flags: str):
mailbox_state[imap_id]["seen"] = True
return "OK", [b""]
def logout(self):
return "BYE", [b""]
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())
channel = EmailChannel(_make_config(), MessageBus())
items = channel._fetch_new_messages()
assert [item["subject"] for item in items] == ["First", "Second"]
def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
class MissingMailboxIMAP:
def login(self, _user: str, _pw: str):
return "OK", [b"logged in"]
def select(self, _mailbox: str):
raise imaplib.IMAP4.error("Mailbox doesn't exist")
def logout(self):
return "BYE", [b""]
monkeypatch.setattr(
"nanobot.channels.email.imaplib.IMAP4_SSL",
lambda _h, _p: MissingMailboxIMAP(),
)
channel = EmailChannel(_make_config(), MessageBus())
assert channel._fetch_new_messages() == []
def test_extract_text_body_falls_back_to_html() -> None: def test_extract_text_body_falls_back_to_html() -> None:
msg = EmailMessage() msg = EmailMessage()
msg["From"] = "alice@example.com" msg["From"] = "alice@example.com"

495
tests/test_onboard_logic.py Normal file
View File

@@ -0,0 +1,495 @@
"""Unit tests for onboard core logic functions.
These tests focus on the business logic behind the onboard wizard,
without testing the interactive UI components.
"""
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
import pytest
from pydantic import BaseModel, Field
from nanobot.cli import onboard_wizard
# Import functions to test
from nanobot.cli.commands import _merge_missing_defaults
from nanobot.cli.onboard_wizard import (
_BACK_PRESSED,
_configure_pydantic_model,
_format_value,
_get_field_display_name,
_get_field_type_info,
run_onboard,
)
from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates
class TestMergeMissingDefaults:
"""Tests for _merge_missing_defaults recursive config merging."""
def test_adds_missing_top_level_keys(self):
existing = {"a": 1}
defaults = {"a": 1, "b": 2, "c": 3}
result = _merge_missing_defaults(existing, defaults)
assert result == {"a": 1, "b": 2, "c": 3}
def test_preserves_existing_values(self):
existing = {"a": "custom_value"}
defaults = {"a": "default_value"}
result = _merge_missing_defaults(existing, defaults)
assert result == {"a": "custom_value"}
def test_merges_nested_dicts_recursively(self):
existing = {
"level1": {
"level2": {
"existing": "kept",
}
}
}
defaults = {
"level1": {
"level2": {
"existing": "replaced",
"added": "new",
},
"level2b": "also_new",
}
}
result = _merge_missing_defaults(existing, defaults)
assert result == {
"level1": {
"level2": {
"existing": "kept",
"added": "new",
},
"level2b": "also_new",
}
}
def test_returns_existing_if_not_dict(self):
assert _merge_missing_defaults("string", {"a": 1}) == "string"
assert _merge_missing_defaults([1, 2, 3], {"a": 1}) == [1, 2, 3]
assert _merge_missing_defaults(None, {"a": 1}) is None
assert _merge_missing_defaults(42, {"a": 1}) == 42
def test_returns_existing_if_defaults_not_dict(self):
assert _merge_missing_defaults({"a": 1}, "string") == {"a": 1}
assert _merge_missing_defaults({"a": 1}, None) == {"a": 1}
def test_handles_empty_dicts(self):
assert _merge_missing_defaults({}, {"a": 1}) == {"a": 1}
assert _merge_missing_defaults({"a": 1}, {}) == {"a": 1}
assert _merge_missing_defaults({}, {}) == {}
def test_backfills_channel_config(self):
"""Real-world scenario: backfill missing channel fields."""
existing_channel = {
"enabled": False,
"appId": "",
"secret": "",
}
default_channel = {
"enabled": False,
"appId": "",
"secret": "",
"msgFormat": "plain",
"allowFrom": [],
}
result = _merge_missing_defaults(existing_channel, default_channel)
assert result["msgFormat"] == "plain"
assert result["allowFrom"] == []
class TestGetFieldTypeInfo:
"""Tests for _get_field_type_info type extraction."""
def test_extracts_str_type(self):
class Model(BaseModel):
field: str
type_name, inner = _get_field_type_info(Model.model_fields["field"])
assert type_name == "str"
assert inner is None
def test_extracts_int_type(self):
class Model(BaseModel):
count: int
type_name, inner = _get_field_type_info(Model.model_fields["count"])
assert type_name == "int"
assert inner is None
def test_extracts_bool_type(self):
class Model(BaseModel):
enabled: bool
type_name, inner = _get_field_type_info(Model.model_fields["enabled"])
assert type_name == "bool"
assert inner is None
def test_extracts_float_type(self):
class Model(BaseModel):
ratio: float
type_name, inner = _get_field_type_info(Model.model_fields["ratio"])
assert type_name == "float"
assert inner is None
def test_extracts_list_type_with_item_type(self):
class Model(BaseModel):
items: list[str]
type_name, inner = _get_field_type_info(Model.model_fields["items"])
assert type_name == "list"
assert inner is str
def test_extracts_list_type_without_item_type(self):
# Plain list without type param falls back to str
class Model(BaseModel):
items: list # type: ignore
# Plain list annotation doesn't match list check, returns str
type_name, inner = _get_field_type_info(Model.model_fields["items"])
assert type_name == "str" # Falls back to str for untyped list
assert inner is None
def test_extracts_dict_type(self):
# Plain dict without type param falls back to str
class Model(BaseModel):
data: dict # type: ignore
# Plain dict annotation doesn't match dict check, returns str
type_name, inner = _get_field_type_info(Model.model_fields["data"])
assert type_name == "str" # Falls back to str for untyped dict
assert inner is None
def test_extracts_optional_type(self):
class Model(BaseModel):
optional: str | None = None
type_name, inner = _get_field_type_info(Model.model_fields["optional"])
# Should unwrap Optional and get str
assert type_name == "str"
assert inner is None
def test_extracts_nested_model_type(self):
class Inner(BaseModel):
x: int
class Outer(BaseModel):
nested: Inner
type_name, inner = _get_field_type_info(Outer.model_fields["nested"])
assert type_name == "model"
assert inner is Inner
def test_handles_none_annotation(self):
"""Field with None annotation defaults to str."""
class Model(BaseModel):
field: Any = None
# Create a mock field_info with None annotation
field_info = SimpleNamespace(annotation=None)
type_name, inner = _get_field_type_info(field_info)
assert type_name == "str"
assert inner is None
class TestGetFieldDisplayName:
"""Tests for _get_field_display_name human-readable name generation."""
def test_uses_description_if_present(self):
class Model(BaseModel):
api_key: str = Field(description="API Key for authentication")
name = _get_field_display_name("api_key", Model.model_fields["api_key"])
assert name == "API Key for authentication"
def test_converts_snake_case_to_title(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("user_name", field_info)
assert name == "User Name"
def test_adds_url_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("api_url", field_info)
# Title case: "Api Url"
assert "Url" in name and "Api" in name
def test_adds_path_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("file_path", field_info)
assert "Path" in name and "File" in name
def test_adds_id_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("user_id", field_info)
# Title case: "User Id"
assert "Id" in name and "User" in name
def test_adds_key_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("api_key", field_info)
assert "Key" in name and "Api" in name
def test_adds_token_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("auth_token", field_info)
assert "Token" in name and "Auth" in name
def test_adds_seconds_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("timeout_s", field_info)
# Contains "(Seconds)" with title case
assert "(Seconds)" in name or "(seconds)" in name
def test_adds_ms_suffix(self):
field_info = SimpleNamespace(description=None)
name = _get_field_display_name("delay_ms", field_info)
# Contains "(Ms)" or "(ms)"
assert "(Ms)" in name or "(ms)" in name
class TestFormatValue:
"""Tests for _format_value display formatting."""
def test_formats_none_as_not_set(self):
assert "not set" in _format_value(None)
def test_formats_empty_string_as_not_set(self):
assert "not set" in _format_value("")
def test_formats_empty_dict_as_not_set(self):
assert "not set" in _format_value({})
def test_formats_empty_list_as_not_set(self):
assert "not set" in _format_value([])
def test_formats_string_value(self):
result = _format_value("hello")
assert "hello" in result
def test_formats_list_value(self):
result = _format_value(["a", "b"])
assert "a" in result or "b" in result
def test_formats_dict_value(self):
result = _format_value({"key": "value"})
assert "key" in result or "value" in result
def test_formats_int_value(self):
result = _format_value(42)
assert "42" in result
def test_formats_bool_true(self):
result = _format_value(True)
assert "true" in result.lower() or "" in result
def test_formats_bool_false(self):
result = _format_value(False)
assert "false" in result.lower() or "" in result
class TestSyncWorkspaceTemplates:
"""Tests for sync_workspace_templates file synchronization."""
def test_creates_missing_files(self, tmp_path):
"""Should create template files that don't exist."""
workspace = tmp_path / "workspace"
added = sync_workspace_templates(workspace, silent=True)
# Check that some files were created
assert isinstance(added, list)
# The actual files depend on the templates directory
def test_does_not_overwrite_existing_files(self, tmp_path):
"""Should not overwrite files that already exist."""
workspace = tmp_path / "workspace"
workspace.mkdir(parents=True)
(workspace / "AGENTS.md").write_text("existing content")
sync_workspace_templates(workspace, silent=True)
# Existing file should not be changed
content = (workspace / "AGENTS.md").read_text()
assert content == "existing content"
def test_creates_memory_directory(self, tmp_path):
"""Should create memory directory structure."""
workspace = tmp_path / "workspace"
sync_workspace_templates(workspace, silent=True)
assert (workspace / "memory").exists() or (workspace / "skills").exists()
def test_returns_list_of_added_files(self, tmp_path):
"""Should return list of relative paths for added files."""
workspace = tmp_path / "workspace"
added = sync_workspace_templates(workspace, silent=True)
assert isinstance(added, list)
# All paths should be relative to workspace
for path in added:
assert not Path(path).is_absolute()
class TestProviderChannelInfo:
"""Tests for provider and channel info retrieval."""
def test_get_provider_names_returns_dict(self):
from nanobot.cli.onboard_wizard import _get_provider_names
names = _get_provider_names()
assert isinstance(names, dict)
assert len(names) > 0
# Should include common providers
assert "openai" in names or "anthropic" in names
assert "openai_codex" not in names
assert "github_copilot" not in names
def test_get_channel_names_returns_dict(self):
from nanobot.cli.onboard_wizard import _get_channel_names
names = _get_channel_names()
assert isinstance(names, dict)
# Should include at least some channels
assert len(names) >= 0
def test_get_provider_info_returns_valid_structure(self):
from nanobot.cli.onboard_wizard import _get_provider_info
info = _get_provider_info()
assert isinstance(info, dict)
# Each value should be a tuple with expected structure
for provider_name, value in info.items():
assert isinstance(value, tuple)
assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var)
class _SimpleDraftModel(BaseModel):
api_key: str = ""
class _NestedDraftModel(BaseModel):
api_key: str = ""
class _OuterDraftModel(BaseModel):
nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)
class TestConfigurePydanticModelDrafts:
@staticmethod
def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
sequence = iter(tokens)
def fake_select(_prompt, choices, default=None):
token = next(sequence)
if token == "first":
return choices[0]
if token == "done":
return "[Done]"
if token == "back":
return _BACK_PRESSED
return token
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
monkeypatch.setattr(
onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
)
def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
model = _SimpleDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "back"])
result = _configure_pydantic_model(model, "Simple")
assert result is None
assert model.api_key == ""
def test_completing_section_returns_updated_draft(self, monkeypatch):
model = _SimpleDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "done"])
result = _configure_pydantic_model(model, "Simple")
assert result is not None
updated = cast(_SimpleDraftModel, result)
assert updated.api_key == "secret"
assert model.api_key == ""
def test_nested_section_back_discards_nested_edits(self, monkeypatch):
model = _OuterDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])
result = _configure_pydantic_model(model, "Outer")
assert result is not None
updated = cast(_OuterDraftModel, result)
assert updated.nested.api_key == ""
assert model.nested.api_key == ""
def test_nested_section_done_commits_nested_edits(self, monkeypatch):
model = _OuterDraftModel()
self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])
result = _configure_pydantic_model(model, "Outer")
assert result is not None
updated = cast(_OuterDraftModel, result)
assert updated.nested.api_key == "secret"
assert model.nested.api_key == ""
class TestRunOnboardExitBehavior:
def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
initial_config = Config()
responses = iter(
[
"[A] Agent Settings",
KeyboardInterrupt(),
"[X] Exit Without Saving",
]
)
class FakePrompt:
def __init__(self, response):
self.response = response
def ask(self):
if isinstance(self.response, BaseException):
raise self.response
return self.response
def fake_select(*_args, **_kwargs):
return FakePrompt(next(responses))
def fake_configure_general_settings(config, section):
if section == "Agent Settings":
config.agents.defaults.model = "test/provider-model"
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings)
result = run_onboard(initial_config=initial_config)
assert result.should_save is False
assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)

View File

@@ -65,6 +65,18 @@ class TestRestartCommand:
mock_handle.assert_called_once() mock_handle.assert_called_once()
@pytest.mark.asyncio
async def test_run_propagates_external_cancellation(self):
"""External task cancellation should not be swallowed by the inbound wait loop."""
loop, _bus = _make_loop()
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
run_task.cancel()
with pytest.raises(asyncio.CancelledError):
await asyncio.wait_for(run_task, timeout=1.0)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_help_includes_restart(self): async def test_help_includes_restart(self):
loop, bus = _make_loop() loop, bus = _make_loop()

View File

@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
def _make_loop(): def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies.""" """Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -23,7 +23,7 @@ def _make_loop():
patch("nanobot.agent.loop.SessionManager"), \ patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace) loop = AgentLoop(bus=bus, provider=provider, workspace=workspace, exec_config=exec_config)
return loop, bus return loop, bus
@@ -90,6 +90,13 @@ class TestHandleStop:
class TestDispatch: class TestDispatch:
def test_exec_tool_not_registered_when_disabled(self):
from nanobot.config.schema import ExecToolConfig
loop, _bus = _make_loop(exec_config=ExecToolConfig(enable=False))
assert loop.tools.get("exec") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dispatch_processes_and_publishes(self): async def test_dispatch_processes_and_publishes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.events import InboundMessage, OutboundMessage

View File

@@ -18,6 +18,10 @@ class _FakeHTTPXRequest:
self.kwargs = kwargs self.kwargs = kwargs
self.__class__.instances.append(self) self.__class__.instances.append(self)
@classmethod
def clear(cls) -> None:
cls.instances.clear()
class _FakeUpdater: class _FakeUpdater:
def __init__(self, on_start_polling) -> None: def __init__(self, on_start_polling) -> None:
@@ -144,7 +148,8 @@ def _make_telegram_update(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None: async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
_FakeHTTPXRequest.clear()
config = TelegramConfig( config = TelegramConfig(
enabled=True, enabled=True,
token="123:abc", token="123:abc",
@@ -164,10 +169,106 @@ async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> No
await channel.start() await channel.start()
assert len(_FakeHTTPXRequest.instances) == 1 assert len(_FakeHTTPXRequest.instances) == 2
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy api_req, poll_req = _FakeHTTPXRequest.instances
assert builder.request_value is _FakeHTTPXRequest.instances[0] assert api_req.kwargs["proxy"] == config.proxy
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0] assert poll_req.kwargs["proxy"] == config.proxy
assert api_req.kwargs["connection_pool_size"] == 32
assert poll_req.kwargs["connection_pool_size"] == 4
assert builder.request_value is api_req
assert builder.get_updates_request_value is poll_req
@pytest.mark.asyncio
async def test_start_respects_custom_pool_config(monkeypatch) -> None:
_FakeHTTPXRequest.clear()
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
connection_pool_size=32,
pool_timeout=10.0,
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
api_req = _FakeHTTPXRequest.instances[0]
poll_req = _FakeHTTPXRequest.instances[1]
assert api_req.kwargs["connection_pool_size"] == 32
assert api_req.kwargs["pool_timeout"] == 10.0
assert poll_req.kwargs["pool_timeout"] == 10.0
@pytest.mark.asyncio
async def test_send_text_retries_on_timeout() -> None:
"""_send_text retries on TimedOut before succeeding."""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
call_count = 0
original_send = channel._app.bot.send_message
async def flaky_send(**kwargs):
nonlocal call_count
call_count += 1
if call_count <= 2:
raise TimedOut()
return await original_send(**kwargs)
channel._app.bot.send_message = flaky_send
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
assert call_count == 3
assert len(channel._app.bot.sent_messages) == 1
@pytest.mark.asyncio
async def test_send_text_gives_up_after_max_retries() -> None:
"""_send_text raises TimedOut after exhausting all retries."""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
async def always_timeout(**kwargs):
raise TimedOut()
channel._app.bot.send_message = always_timeout
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
assert channel._app.bot.sent_messages == []
def test_derive_topic_session_key_uses_thread_id() -> None: def test_derive_topic_session_key_uses_thread_id() -> None:

View File

@@ -406,3 +406,64 @@ async def test_exec_timeout_capped_at_max() -> None:
# Should not raise — just clamp to 600 # Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999) result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result assert "Exit code: 0" in result
# --- _resolve_type and nullable param tests ---
def test_resolve_type_simple_string() -> None:
"""Simple string type passes through unchanged."""
assert Tool._resolve_type("string") == "string"
def test_resolve_type_union_with_null() -> None:
"""Union type ['string', 'null'] resolves to 'string'."""
assert Tool._resolve_type(["string", "null"]) == "string"
def test_resolve_type_only_null() -> None:
"""Union type ['null'] resolves to None (no non-null type)."""
assert Tool._resolve_type(["null"]) is None
def test_resolve_type_none_input() -> None:
"""None input passes through as None."""
assert Tool._resolve_type(None) is None
def test_validate_nullable_param_accepts_string() -> None:
"""Nullable string param should accept a string value."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
errors = tool.validate_params({"name": "hello"})
assert errors == []
def test_validate_nullable_param_accepts_none() -> None:
"""Nullable string param should accept None."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
errors = tool.validate_params({"name": None})
assert errors == []
def test_cast_nullable_param_no_crash() -> None:
"""cast_params should not crash on nullable type (the original bug)."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": ["string", "null"]}},
}
)
result = tool.cast_params({"name": "hello"})
assert result["name"] == "hello"
result = tool.cast_params({"name": None})
assert result["name"] is None