Merge branch 'main' into pr-1920

This commit is contained in:
Xubin Ren
2026-03-13 04:20:14 +00:00
14 changed files with 545 additions and 55 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Test Suite
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libolm-dev build-essential
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
- name: Run tests
run: python -m pytest tests/ -v

3
.gitignore vendored
View File

@@ -1,5 +1,6 @@
.worktrees/ .worktrees/
.assets .assets
.docs
.env .env
*.pyc *.pyc
dist/ dist/
@@ -7,7 +8,7 @@ build/
docs/ docs/
*.egg-info/ *.egg-info/
*.egg *.egg
*.pyc *.pycs
*.pyo *.pyo
*.pyd *.pyd
*.pyw *.pyw

View File

@@ -758,15 +758,17 @@ Config file: `~/.nanobot/config.json`
> [!TIP] > [!TIP]
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. > - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config. > - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key | | Provider | Purpose | Get API Key |
|----------|---------|-------------| |----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | | `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | | `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
@@ -776,7 +778,6 @@ Config file: `~/.nanobot/config.json`
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) | | `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |

View File

@@ -139,7 +139,7 @@ class AgentLoop:
await self._mcp_stack.__aenter__() await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True self._mcp_connected = True
except Exception as e: except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e) logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack: if self._mcp_stack:
try: try:
@@ -292,7 +292,9 @@ class AgentLoop:
async def _do_restart(): async def _do_restart():
await asyncio.sleep(1) await asyncio.sleep(1)
os.execv(sys.executable, [sys.executable] + sys.argv) # Use -m nanobot instead of sys.argv[0] for Windows compatibility
# (sys.argv[0] may be just "nanobot" without full path on Windows)
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
asyncio.create_task(_do_restart()) asyncio.create_task(_do_restart())

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import weakref import weakref
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@@ -57,13 +58,30 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
return args[0] if args and isinstance(args[0], dict) else None return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None return args if isinstance(args, dict) else None
_TOOL_CHOICE_ERROR_MARKERS = (
"tool_choice",
"toolchoice",
"does not support",
'should be ["none", "auto"]',
)
def _is_tool_choice_unsupported(content: str | None) -> bool:
"""Detect provider errors caused by forced tool_choice being unsupported."""
text = (content or "").lower()
return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS)
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.memory_dir = ensure_dir(workspace / "memory") self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md" self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "HISTORY.md" self.history_file = self.memory_dir / "HISTORY.md"
self._consecutive_failures = 0
def read_long_term(self) -> str: def read_long_term(self) -> str:
if self.memory_file.exists(): if self.memory_file.exists():
@@ -112,38 +130,93 @@ class MemoryStore:
## Conversation to Process ## Conversation to Process
{self._format_messages(messages)}""" {self._format_messages(messages)}"""
chat_messages = [
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
]
try: try:
forced = {"type": "function", "function": {"name": "save_memory"}}
response = await provider.chat_with_retry( response = await provider.chat_with_retry(
messages=[ messages=chat_messages,
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
],
tools=_SAVE_MEMORY_TOOL, tools=_SAVE_MEMORY_TOOL,
model=model, model=model,
tool_choice="required", tool_choice=forced,
) )
if response.finish_reason == "error" and _is_tool_choice_unsupported(
response.content
):
logger.warning("Forced tool_choice unsupported, retrying with auto")
response = await provider.chat_with_retry(
messages=chat_messages,
tools=_SAVE_MEMORY_TOOL,
model=model,
tool_choice="auto",
)
if not response.has_tool_calls: if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping") logger.warning(
return False "Memory consolidation: LLM did not call save_memory "
"(finish_reason={}, content_len={}, content_preview={})",
response.finish_reason,
len(response.content or ""),
(response.content or "")[:200],
)
return self._fail_or_raw_archive(messages)
args = _normalize_save_memory_args(response.tool_calls[0].arguments) args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None: if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments") logger.warning("Memory consolidation: unexpected save_memory arguments")
return False return self._fail_or_raw_archive(messages)
if entry := args.get("history_entry"): if "history_entry" not in args or "memory_update" not in args:
self.append_history(_ensure_text(entry)) logger.warning("Memory consolidation: save_memory payload missing required fields")
if update := args.get("memory_update"): return self._fail_or_raw_archive(messages)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
entry = args["history_entry"]
update = args["memory_update"]
if entry is None or update is None:
logger.warning("Memory consolidation: save_memory payload contains null required fields")
return self._fail_or_raw_archive(messages)
entry = _ensure_text(entry).strip()
if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization")
return self._fail_or_raw_archive(messages)
self.append_history(entry)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
self._consecutive_failures = 0
logger.info("Memory consolidation done for {} messages", len(messages)) logger.info("Memory consolidation done for {} messages", len(messages))
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return self._fail_or_raw_archive(messages)
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
"""Increment failure count; after threshold, raw-archive messages and return True."""
self._consecutive_failures += 1
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
return False return False
self._raw_archive(messages)
self._consecutive_failures = 0
return True
def _raw_archive(self, messages: list[dict]) -> None:
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
self.append_history(
f"[{ts}] [RAW] {len(messages)} messages\n"
f"{self._format_messages(messages)}"
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
class MemoryConsolidator: class MemoryConsolidator:

View File

@@ -149,13 +149,22 @@ class MatrixChannel(BaseChannel):
name = "matrix" name = "matrix"
display_name = "Matrix" display_name = "Matrix"
def __init__(self, config: Any, bus: MessageBus): def __init__(
self,
config: Any,
bus: MessageBus,
*,
restrict_to_workspace: bool = False,
workspace: str | Path | None = None,
):
super().__init__(config, bus) super().__init__(config, bus)
self.client: AsyncClient | None = None self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tasks: dict[str, asyncio.Task] = {}
self._restrict_to_workspace = False self._restrict_to_workspace = bool(restrict_to_workspace)
self._workspace: Path | None = None self._workspace = (
Path(workspace).expanduser().resolve(strict=False) if workspace is not None else None
)
self._server_upload_limit_bytes: int | None = None self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False self._server_upload_limit_checked = False

View File

@@ -114,16 +114,16 @@ class QQChannel(BaseChannel):
if msg_type == "group": if msg_type == "group":
await self._client.api.post_group_message( await self._client.api.post_group_message(
group_openid=msg.chat_id, group_openid=msg.chat_id,
msg_type=2, msg_type=0,
markdown={"content": msg.content}, content=msg.content,
msg_id=msg_id, msg_id=msg_id,
msg_seq=self._msg_seq, msg_seq=self._msg_seq,
) )
else: else:
await self._client.api.post_c2c_message( await self._client.api.post_c2c_message(
openid=msg.chat_id, openid=msg.chat_id,
msg_type=2, msg_type=0,
markdown={"content": msg.content}, content=msg.content,
msg_id=msg_id, msg_id=msg_id,
msg_seq=self._msg_seq, msg_seq=self._msg_seq,
) )

View File

@@ -19,10 +19,12 @@ if sys.platform == "win32":
pass pass
import typer import typer
from prompt_toolkit import print_formatted_text
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
@@ -111,8 +113,25 @@ def _init_prompt_session() -> None:
) )
def _make_console() -> Console:
return Console(file=sys.stdout)
def _render_interactive_ansi(render_fn) -> str:
"""Render Rich output to ANSI so prompt_toolkit can print it safely."""
ansi_console = Console(
force_terminal=True,
color_system=console.color_system or "standard",
width=console.width,
)
with ansi_console.capture() as capture:
render_fn(ansi_console)
return capture.get()
def _print_agent_response(response: str, render_markdown: bool) -> None: def _print_agent_response(response: str, render_markdown: bool) -> None:
"""Render assistant response with consistent terminal styling.""" """Render assistant response with consistent terminal styling."""
console = _make_console()
content = response or "" content = response or ""
body = Markdown(content) if render_markdown else Text(content) body = Markdown(content) if render_markdown else Text(content)
console.print() console.print()
@@ -121,6 +140,34 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
console.print() console.print()
async def _print_interactive_line(text: str) -> None:
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
def _write() -> None:
ansi = _render_interactive_ansi(
lambda c: c.print(f" [dim]↳ {text}[/dim]")
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
def _write() -> None:
content = response or ""
ansi = _render_interactive_ansi(
lambda c: (
c.print(),
c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
c.print(Markdown(content) if render_markdown else Text(content)),
c.print(),
)
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
def _is_exit_command(command: str) -> bool: def _is_exit_command(command: str) -> bool:
"""Return True when input should end interactive chat.""" """Return True when input should end interactive chat."""
return command.lower() in EXIT_COMMANDS return command.lower() in EXIT_COMMANDS
@@ -610,14 +657,15 @@ def agent(
elif ch and not is_tool_hint and not ch.send_progress: elif ch and not is_tool_hint and not ch.send_progress:
pass pass
else: else:
console.print(f" [dim]↳ {msg.content}[/dim]") await _print_interactive_line(msg.content)
elif not turn_done.is_set(): elif not turn_done.is_set():
if msg.content: if msg.content:
turn_response.append(msg.content) turn_response.append(msg.content)
turn_done.set() turn_done.set()
elif msg.content: elif msg.content:
console.print() await _print_interactive_response(msg.content, render_markdown=markdown)
_print_agent_response(msg.content, render_markdown=markdown)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -276,15 +276,18 @@ class ProvidersConfig(Base):
deepseek: ProviderConfig = Field(default_factory=ProviderConfig) deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig) groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig) zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问 dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig) vllm: ProviderConfig = Field(default_factory=ProviderConfig)
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
gemini: ProviderConfig = Field(default_factory=ProviderConfig) gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -398,12 +401,21 @@ class Config(BaseSettings):
# Fallback: configured local providers can route models without # Fallback: configured local providers can route models without
# provider-specific keywords (for example plain "llama3.2" on Ollama). # provider-specific keywords (for example plain "llama3.2" on Ollama).
# Prefer providers whose detect_by_base_keyword matches the configured api_base
# (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
local_fallback: tuple[ProviderConfig, str] | None = None
for spec in PROVIDERS: for spec in PROVIDERS:
if not spec.is_local: if not spec.is_local:
continue continue
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and p.api_base: if not (p and p.api_base):
continue
if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
return p, spec.name return p, spec.name
if local_fallback is None:
local_fallback = (p, spec.name)
if local_fallback:
return local_fallback
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection # OAuth providers are NOT valid fallbacks — they require explicit model selection

View File

@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine (火山引擎): OpenAI-compatible gateway
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec( ProviderSpec(
name="volcengine", name="volcengine",
keywords=("volcengine", "volces", "ark"), keywords=("volcengine", "volces", "ark"),
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
ProviderSpec(
name="volcengine_coding_plan",
keywords=("volcengine-plan",),
env_key="OPENAI_API_KEY",
display_name="VolcEngine Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus: VolcEngine international, pay-per-use models
ProviderSpec(
name="byteplus",
keywords=("byteplus",),
env_key="OPENAI_API_KEY",
display_name="BytePlus",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="bytepluses",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus Coding Plan: same key as byteplus
ProviderSpec(
name="byteplus_coding_plan",
keywords=("byteplus-plan",),
env_key="OPENAI_API_KEY",
display_name="BytePlus Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) =============== # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(

View File

@@ -75,13 +75,6 @@ build-backend = "hatchling.build"
[tool.hatch.metadata] [tool.hatch.metadata]
allow-direct-references = true allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["nanobot"]
[tool.hatch.build.targets.wheel.sources]
"nanobot" = "nanobot"
# Include non-Python files in skills and templates
[tool.hatch.build] [tool.hatch.build]
include = [ include = [
"nanobot/**/*.py", "nanobot/**/*.py",
@@ -90,6 +83,15 @@ include = [
"nanobot/skills/**/*.sh", "nanobot/skills/**/*.sh",
] ]
[tool.hatch.build.targets.wheel]
packages = ["nanobot"]
[tool.hatch.build.targets.wheel.sources]
"nanobot" = "nanobot"
[tool.hatch.build.targets.wheel.force-include]
"bridge" = "nanobot/bridge"
[tool.hatch.build.targets.sdist] [tool.hatch.build.targets.sdist]
include = [ include = [
"nanobot/", "nanobot/",
@@ -98,9 +100,6 @@ include = [
"LICENSE", "LICENSE",
] ]
[tool.hatch.build.targets.wheel.force-include]
"bridge" = "nanobot/bridge"
[tool.ruff] [tool.ruff]
line-length = 100 line-length = 100
target-version = "py311" target-version = "py311"

View File

@@ -1,3 +1,4 @@
import re
import shutil import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -11,6 +12,12 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model from nanobot.providers.registry import find_by_model
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
runner = CliRunner() runner = CliRunner()
@@ -143,6 +150,35 @@ def test_config_auto_detects_ollama_from_local_api_base():
assert config.get_api_base() == "http://localhost:11434" assert config.get_api_base() == "http://localhost:11434"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex") spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -199,10 +235,11 @@ def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"]) result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "--workspace" in result.stdout stripped_output = _strip_ansi(result.stdout)
assert "-w" in result.stdout assert "--workspace" in stripped_output
assert "--config" in result.stdout assert "-w" in stripped_output
assert "-c" in result.stdout assert "--config" in stripped_output
assert "-c" in stripped_output
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime): def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):

View File

@@ -112,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -170,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -242,6 +240,94 @@ class TestMemoryConsolidationTypeHandling:
assert result is False assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None: async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
@@ -288,3 +374,105 @@ class TestMemoryConsolidationTypeHandling:
assert "temperature" not in kwargs assert "temperature" not in kwargs
assert "max_tokens" not in kwargs assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error calling LLM: litellm.BadRequestError: "
"The tool_choice parameter does not support being set to required or object",
finish_reason="error",
tool_calls=[],
)
ok_resp = _make_tool_response(
history_entry="[2026-01-01] Fallback worked.",
memory_update="# Memory\nFallback OK.",
)
call_log: list[dict] = []
async def _tracking_chat(**kwargs):
call_log.append(kwargs)
return error_resp if len(call_log) == 1 else ok_resp
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert len(call_log) == 2
assert isinstance(call_log[0]["tool_choice"], dict)
assert call_log[1]["tool_choice"] == "auto"
assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
"""Forced rejected, auto retry also produces no tool call -> return False."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error: tool_choice must be none or auto",
finish_reason="error",
tool_calls=[],
)
no_tool_resp = LLMResponse(
content="Here is a summary.",
finish_reason="stop",
tool_calls=[],
)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1

View File

@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None: async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient() channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group" channel._chat_type_cache["group123"] = "group"
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
assert len(channel._client.api.group_calls) == 1 assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0] call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123" assert call == {
assert call["msg_id"] == "msg1" "group_openid": "group123",
assert call["msg_seq"] == 2 "msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls