Merge branch 'main' into pr-1920
This commit is contained in:
33
.github/workflows/ci.yml
vendored
Normal file
33
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
name: Test Suite
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.11", "3.12", "3.13"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install system dependencies
|
||||||
|
run: sudo apt-get update && sudo apt-get install -y libolm-dev build-essential
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: python -m pytest tests/ -v
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
.worktrees/
|
.worktrees/
|
||||||
.assets
|
.assets
|
||||||
|
.docs
|
||||||
.env
|
.env
|
||||||
*.pyc
|
*.pyc
|
||||||
dist/
|
dist/
|
||||||
@@ -7,7 +8,7 @@ build/
|
|||||||
docs/
|
docs/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
*.egg
|
*.egg
|
||||||
*.pyc
|
*.pycs
|
||||||
*.pyo
|
*.pyo
|
||||||
*.pyd
|
*.pyd
|
||||||
*.pyw
|
*.pyw
|
||||||
|
|||||||
@@ -758,15 +758,17 @@ Config file: `~/.nanobot/config.json`
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
||||||
|
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
|
||||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||||
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
> - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
|
||||||
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
|
|
||||||
|
|
||||||
| Provider | Purpose | Get API Key |
|
| Provider | Purpose | Get API Key |
|
||||||
|----------|---------|-------------|
|
|----------|---------|-------------|
|
||||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||||
|
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
|
||||||
|
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
|
||||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||||
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||||
@@ -776,7 +778,6 @@ Config file: `~/.nanobot/config.json`
|
|||||||
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
||||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||||
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||||
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
|
|
||||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class AgentLoop:
|
|||||||
await self._mcp_stack.__aenter__()
|
await self._mcp_stack.__aenter__()
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||||
self._mcp_connected = True
|
self._mcp_connected = True
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
if self._mcp_stack:
|
if self._mcp_stack:
|
||||||
try:
|
try:
|
||||||
@@ -292,7 +292,9 @@ class AgentLoop:
|
|||||||
|
|
||||||
async def _do_restart():
|
async def _do_restart():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
os.execv(sys.executable, [sys.executable] + sys.argv)
|
# Use -m nanobot instead of sys.argv[0] for Windows compatibility
|
||||||
|
# (sys.argv[0] may be just "nanobot" without full path on Windows)
|
||||||
|
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
|
||||||
|
|
||||||
asyncio.create_task(_do_restart())
|
asyncio.create_task(_do_restart())
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import weakref
|
import weakref
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
@@ -57,13 +58,30 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
|||||||
return args[0] if args and isinstance(args[0], dict) else None
|
return args[0] if args and isinstance(args[0], dict) else None
|
||||||
return args if isinstance(args, dict) else None
|
return args if isinstance(args, dict) else None
|
||||||
|
|
||||||
|
_TOOL_CHOICE_ERROR_MARKERS = (
|
||||||
|
"tool_choice",
|
||||||
|
"toolchoice",
|
||||||
|
"does not support",
|
||||||
|
'should be ["none", "auto"]',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tool_choice_unsupported(content: str | None) -> bool:
|
||||||
|
"""Detect provider errors caused by forced tool_choice being unsupported."""
|
||||||
|
text = (content or "").lower()
|
||||||
|
return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS)
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
|
|
||||||
|
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.memory_dir = ensure_dir(workspace / "memory")
|
self.memory_dir = ensure_dir(workspace / "memory")
|
||||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||||
self.history_file = self.memory_dir / "HISTORY.md"
|
self.history_file = self.memory_dir / "HISTORY.md"
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
|
||||||
def read_long_term(self) -> str:
|
def read_long_term(self) -> str:
|
||||||
if self.memory_file.exists():
|
if self.memory_file.exists():
|
||||||
@@ -112,38 +130,93 @@ class MemoryStore:
|
|||||||
## Conversation to Process
|
## Conversation to Process
|
||||||
{self._format_messages(messages)}"""
|
{self._format_messages(messages)}"""
|
||||||
|
|
||||||
|
chat_messages = [
|
||||||
|
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
forced = {"type": "function", "function": {"name": "save_memory"}}
|
||||||
response = await provider.chat_with_retry(
|
response = await provider.chat_with_retry(
|
||||||
messages=[
|
messages=chat_messages,
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
tools=_SAVE_MEMORY_TOOL,
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
model=model,
|
model=model,
|
||||||
tool_choice="required",
|
tool_choice=forced,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response.finish_reason == "error" and _is_tool_choice_unsupported(
|
||||||
|
response.content
|
||||||
|
):
|
||||||
|
logger.warning("Forced tool_choice unsupported, retrying with auto")
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=chat_messages,
|
||||||
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
|
model=model,
|
||||||
|
tool_choice="auto",
|
||||||
|
)
|
||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
logger.warning(
|
||||||
return False
|
"Memory consolidation: LLM did not call save_memory "
|
||||||
|
"(finish_reason={}, content_len={}, content_preview={})",
|
||||||
|
response.finish_reason,
|
||||||
|
len(response.content or ""),
|
||||||
|
(response.content or "")[:200],
|
||||||
|
)
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||||
if args is None:
|
if args is None:
|
||||||
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||||
return False
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if "history_entry" not in args or "memory_update" not in args:
|
||||||
self.append_history(_ensure_text(entry))
|
logger.warning("Memory consolidation: save_memory payload missing required fields")
|
||||||
if update := args.get("memory_update"):
|
return self._fail_or_raw_archive(messages)
|
||||||
update = _ensure_text(update)
|
|
||||||
if update != current_memory:
|
|
||||||
self.write_long_term(update)
|
|
||||||
|
|
||||||
|
entry = args["history_entry"]
|
||||||
|
update = args["memory_update"]
|
||||||
|
|
||||||
|
if entry is None or update is None:
|
||||||
|
logger.warning("Memory consolidation: save_memory payload contains null required fields")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
entry = _ensure_text(entry).strip()
|
||||||
|
if not entry:
|
||||||
|
logger.warning("Memory consolidation: history_entry is empty after normalization")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
self.append_history(entry)
|
||||||
|
update = _ensure_text(update)
|
||||||
|
if update != current_memory:
|
||||||
|
self.write_long_term(update)
|
||||||
|
|
||||||
|
self._consecutive_failures = 0
|
||||||
logger.info("Memory consolidation done for {} messages", len(messages))
|
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Memory consolidation failed")
|
logger.exception("Memory consolidation failed")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
|
||||||
|
"""Increment failure count; after threshold, raw-archive messages and return True."""
|
||||||
|
self._consecutive_failures += 1
|
||||||
|
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
|
||||||
return False
|
return False
|
||||||
|
self._raw_archive(messages)
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _raw_archive(self, messages: list[dict]) -> None:
|
||||||
|
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
|
||||||
|
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||||
|
self.append_history(
|
||||||
|
f"[{ts}] [RAW] {len(messages)} messages\n"
|
||||||
|
f"{self._format_messages(messages)}"
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Memory consolidation degraded: raw-archived {} messages", len(messages)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryConsolidator:
|
class MemoryConsolidator:
|
||||||
|
|||||||
@@ -149,13 +149,22 @@ class MatrixChannel(BaseChannel):
|
|||||||
name = "matrix"
|
name = "matrix"
|
||||||
display_name = "Matrix"
|
display_name = "Matrix"
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Any,
|
||||||
|
bus: MessageBus,
|
||||||
|
*,
|
||||||
|
restrict_to_workspace: bool = False,
|
||||||
|
workspace: str | Path | None = None,
|
||||||
|
):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.client: AsyncClient | None = None
|
self.client: AsyncClient | None = None
|
||||||
self._sync_task: asyncio.Task | None = None
|
self._sync_task: asyncio.Task | None = None
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||||
self._restrict_to_workspace = False
|
self._restrict_to_workspace = bool(restrict_to_workspace)
|
||||||
self._workspace: Path | None = None
|
self._workspace = (
|
||||||
|
Path(workspace).expanduser().resolve(strict=False) if workspace is not None else None
|
||||||
|
)
|
||||||
self._server_upload_limit_bytes: int | None = None
|
self._server_upload_limit_bytes: int | None = None
|
||||||
self._server_upload_limit_checked = False
|
self._server_upload_limit_checked = False
|
||||||
|
|
||||||
|
|||||||
@@ -114,16 +114,16 @@ class QQChannel(BaseChannel):
|
|||||||
if msg_type == "group":
|
if msg_type == "group":
|
||||||
await self._client.api.post_group_message(
|
await self._client.api.post_group_message(
|
||||||
group_openid=msg.chat_id,
|
group_openid=msg.chat_id,
|
||||||
msg_type=2,
|
msg_type=0,
|
||||||
markdown={"content": msg.content},
|
content=msg.content,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
msg_seq=self._msg_seq,
|
msg_seq=self._msg_seq,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._client.api.post_c2c_message(
|
await self._client.api.post_c2c_message(
|
||||||
openid=msg.chat_id,
|
openid=msg.chat_id,
|
||||||
msg_type=2,
|
msg_type=0,
|
||||||
markdown={"content": msg.content},
|
content=msg.content,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
msg_seq=self._msg_seq,
|
msg_seq=self._msg_seq,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ if sys.platform == "win32":
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from prompt_toolkit import print_formatted_text
|
||||||
from prompt_toolkit import PromptSession
|
from prompt_toolkit import PromptSession
|
||||||
from prompt_toolkit.formatted_text import HTML
|
from prompt_toolkit.formatted_text import ANSI, HTML
|
||||||
from prompt_toolkit.history import FileHistory
|
from prompt_toolkit.history import FileHistory
|
||||||
from prompt_toolkit.patch_stdout import patch_stdout
|
from prompt_toolkit.patch_stdout import patch_stdout
|
||||||
|
from prompt_toolkit.application import run_in_terminal
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
@@ -111,8 +113,25 @@ def _init_prompt_session() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_console() -> Console:
|
||||||
|
return Console(file=sys.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_interactive_ansi(render_fn) -> str:
|
||||||
|
"""Render Rich output to ANSI so prompt_toolkit can print it safely."""
|
||||||
|
ansi_console = Console(
|
||||||
|
force_terminal=True,
|
||||||
|
color_system=console.color_system or "standard",
|
||||||
|
width=console.width,
|
||||||
|
)
|
||||||
|
with ansi_console.capture() as capture:
|
||||||
|
render_fn(ansi_console)
|
||||||
|
return capture.get()
|
||||||
|
|
||||||
|
|
||||||
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||||
"""Render assistant response with consistent terminal styling."""
|
"""Render assistant response with consistent terminal styling."""
|
||||||
|
console = _make_console()
|
||||||
content = response or ""
|
content = response or ""
|
||||||
body = Markdown(content) if render_markdown else Text(content)
|
body = Markdown(content) if render_markdown else Text(content)
|
||||||
console.print()
|
console.print()
|
||||||
@@ -121,6 +140,34 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
|
|||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
|
||||||
|
async def _print_interactive_line(text: str) -> None:
|
||||||
|
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
|
||||||
|
def _write() -> None:
|
||||||
|
ansi = _render_interactive_ansi(
|
||||||
|
lambda c: c.print(f" [dim]↳ {text}[/dim]")
|
||||||
|
)
|
||||||
|
print_formatted_text(ANSI(ansi), end="")
|
||||||
|
|
||||||
|
await run_in_terminal(_write)
|
||||||
|
|
||||||
|
|
||||||
|
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
|
||||||
|
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
|
||||||
|
def _write() -> None:
|
||||||
|
content = response or ""
|
||||||
|
ansi = _render_interactive_ansi(
|
||||||
|
lambda c: (
|
||||||
|
c.print(),
|
||||||
|
c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
|
||||||
|
c.print(Markdown(content) if render_markdown else Text(content)),
|
||||||
|
c.print(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print_formatted_text(ANSI(ansi), end="")
|
||||||
|
|
||||||
|
await run_in_terminal(_write)
|
||||||
|
|
||||||
|
|
||||||
def _is_exit_command(command: str) -> bool:
|
def _is_exit_command(command: str) -> bool:
|
||||||
"""Return True when input should end interactive chat."""
|
"""Return True when input should end interactive chat."""
|
||||||
return command.lower() in EXIT_COMMANDS
|
return command.lower() in EXIT_COMMANDS
|
||||||
@@ -610,14 +657,15 @@ def agent(
|
|||||||
elif ch and not is_tool_hint and not ch.send_progress:
|
elif ch and not is_tool_hint and not ch.send_progress:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
console.print(f" [dim]↳ {msg.content}[/dim]")
|
await _print_interactive_line(msg.content)
|
||||||
|
|
||||||
elif not turn_done.is_set():
|
elif not turn_done.is_set():
|
||||||
if msg.content:
|
if msg.content:
|
||||||
turn_response.append(msg.content)
|
turn_response.append(msg.content)
|
||||||
turn_done.set()
|
turn_done.set()
|
||||||
elif msg.content:
|
elif msg.content:
|
||||||
console.print()
|
await _print_interactive_response(msg.content, render_markdown=markdown)
|
||||||
_print_agent_response(msg.content, render_markdown=markdown)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
@@ -276,15 +276,18 @@ class ProvidersConfig(Base):
|
|||||||
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
|
dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
|
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
||||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
|
||||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||||
|
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
||||||
|
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
|
||||||
|
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
|
||||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||||
|
|
||||||
@@ -398,12 +401,21 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
# Fallback: configured local providers can route models without
|
# Fallback: configured local providers can route models without
|
||||||
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
||||||
|
# Prefer providers whose detect_by_base_keyword matches the configured api_base
|
||||||
|
# (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
|
||||||
|
local_fallback: tuple[ProviderConfig, str] | None = None
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
if not spec.is_local:
|
if not spec.is_local:
|
||||||
continue
|
continue
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and p.api_base:
|
if not (p and p.api_base):
|
||||||
|
continue
|
||||||
|
if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
|
||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
if local_fallback is None:
|
||||||
|
local_fallback = (p, spec.name)
|
||||||
|
if local_fallback:
|
||||||
|
return local_fallback
|
||||||
|
|
||||||
# Fallback: gateways first, then others (follows registry order)
|
# Fallback: gateways first, then others (follows registry order)
|
||||||
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
||||||
|
|||||||
@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
|
||||||
|
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="volcengine",
|
name="volcengine",
|
||||||
keywords=("volcengine", "volces", "ark"),
|
keywords=("volcengine", "volces", "ark"),
|
||||||
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
|
||||||
|
ProviderSpec(
|
||||||
|
name="volcengine_coding_plan",
|
||||||
|
keywords=("volcengine-plan",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="VolcEngine Coding Plan",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="",
|
||||||
|
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
# BytePlus: VolcEngine international, pay-per-use models
|
||||||
|
ProviderSpec(
|
||||||
|
name="byteplus",
|
||||||
|
keywords=("byteplus",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="BytePlus",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="bytepluses",
|
||||||
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
# BytePlus Coding Plan: same key as byteplus
|
||||||
|
ProviderSpec(
|
||||||
|
name="byteplus_coding_plan",
|
||||||
|
keywords=("byteplus-plan",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="BytePlus Coding Plan",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="",
|
||||||
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
# === Standard providers (matched by model-name keywords) ===============
|
# === Standard providers (matched by model-name keywords) ===============
|
||||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
|
|||||||
@@ -75,13 +75,6 @@ build-backend = "hatchling.build"
|
|||||||
[tool.hatch.metadata]
|
[tool.hatch.metadata]
|
||||||
allow-direct-references = true
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
|
||||||
packages = ["nanobot"]
|
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel.sources]
|
|
||||||
"nanobot" = "nanobot"
|
|
||||||
|
|
||||||
# Include non-Python files in skills and templates
|
|
||||||
[tool.hatch.build]
|
[tool.hatch.build]
|
||||||
include = [
|
include = [
|
||||||
"nanobot/**/*.py",
|
"nanobot/**/*.py",
|
||||||
@@ -90,6 +83,15 @@ include = [
|
|||||||
"nanobot/skills/**/*.sh",
|
"nanobot/skills/**/*.sh",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["nanobot"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.sources]
|
||||||
|
"nanobot" = "nanobot"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.force-include]
|
||||||
|
"bridge" = "nanobot/bridge"
|
||||||
|
|
||||||
[tool.hatch.build.targets.sdist]
|
[tool.hatch.build.targets.sdist]
|
||||||
include = [
|
include = [
|
||||||
"nanobot/",
|
"nanobot/",
|
||||||
@@ -98,9 +100,6 @@ include = [
|
|||||||
"LICENSE",
|
"LICENSE",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel.force-include]
|
|
||||||
"bridge" = "nanobot/bridge"
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@@ -11,6 +12,12 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
|
|||||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||||
from nanobot.providers.registry import find_by_model
|
from nanobot.providers.registry import find_by_model
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_ansi(text):
|
||||||
|
"""Remove ANSI escape codes from text."""
|
||||||
|
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
|
||||||
|
return ansi_escape.sub('', text)
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
@@ -143,6 +150,35 @@ def test_config_auto_detects_ollama_from_local_api_base():
|
|||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
|
"providers": {
|
||||||
|
"vllm": {"apiBase": "http://localhost:8000"},
|
||||||
|
"ollama": {"apiBase": "http://localhost:11434"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "ollama"
|
||||||
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
|
"providers": {
|
||||||
|
"vllm": {"apiBase": "http://localhost:8000"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "vllm"
|
||||||
|
assert config.get_api_base() == "http://localhost:8000"
|
||||||
|
|
||||||
|
|
||||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||||
|
|
||||||
@@ -199,10 +235,11 @@ def test_agent_help_shows_workspace_and_config_options():
|
|||||||
result = runner.invoke(app, ["agent", "--help"])
|
result = runner.invoke(app, ["agent", "--help"])
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "--workspace" in result.stdout
|
stripped_output = _strip_ansi(result.stdout)
|
||||||
assert "-w" in result.stdout
|
assert "--workspace" in stripped_output
|
||||||
assert "--config" in result.stdout
|
assert "-w" in stripped_output
|
||||||
assert "-c" in result.stdout
|
assert "--config" in stripped_output
|
||||||
|
assert "-c" in stripped_output
|
||||||
|
|
||||||
|
|
||||||
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
|
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a JSON string (not yet parsed)
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -170,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a list containing a dict
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -242,6 +240,94 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not persist partial results when required fields are missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"memory_update": "# Memory\nOnly memory update"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not append history if memory_update is missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"history_entry": "[2026-01-01] Partial output."},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Null required fields should be rejected before persistence."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=None,
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=" ",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
@@ -288,3 +374,105 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
assert "temperature" not in kwargs
|
assert "temperature" not in kwargs
|
||||||
assert "max_tokens" not in kwargs
|
assert "max_tokens" not in kwargs
|
||||||
assert "reasoning_effort" not in kwargs
|
assert "reasoning_effort" not in kwargs
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
|
||||||
|
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
error_resp = LLMResponse(
|
||||||
|
content="Error calling LLM: litellm.BadRequestError: "
|
||||||
|
"The tool_choice parameter does not support being set to required or object",
|
||||||
|
finish_reason="error",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
ok_resp = _make_tool_response(
|
||||||
|
history_entry="[2026-01-01] Fallback worked.",
|
||||||
|
memory_update="# Memory\nFallback OK.",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_log: list[dict] = []
|
||||||
|
|
||||||
|
async def _tracking_chat(**kwargs):
|
||||||
|
call_log.append(kwargs)
|
||||||
|
return error_resp if len(call_log) == 1 else ok_resp
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(call_log) == 2
|
||||||
|
assert isinstance(call_log[0]["tool_choice"], dict)
|
||||||
|
assert call_log[1]["tool_choice"] == "auto"
|
||||||
|
assert "Fallback worked." in store.history_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
|
||||||
|
"""Forced rejected, auto retry also produces no tool call -> return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
error_resp = LLMResponse(
|
||||||
|
content="Error: tool_choice must be none or auto",
|
||||||
|
finish_reason="error",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
no_tool_resp = LLMResponse(
|
||||||
|
content="Here is a summary.",
|
||||||
|
finish_reason="stop",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
|
||||||
|
"""After 3 consecutive failures, raw-archive messages and return True."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
messages = _make_messages(message_count=10)
|
||||||
|
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is True
|
||||||
|
|
||||||
|
assert store.history_file.exists()
|
||||||
|
content = store.history_file.read_text()
|
||||||
|
assert "[RAW]" in content
|
||||||
|
assert "10 messages" in content
|
||||||
|
assert "msg0" in content
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
|
||||||
|
"""A successful consolidation resets the failure counter."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
|
||||||
|
ok_resp = _make_tool_response(
|
||||||
|
history_entry="[2026-01-01] OK.",
|
||||||
|
memory_update="# Memory\nOK.",
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=10)
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert store._consecutive_failures == 2
|
||||||
|
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is True
|
||||||
|
assert store._consecutive_failures == 0
|
||||||
|
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert store._consecutive_failures == 1
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
|
||||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
channel._client = _FakeClient()
|
channel._client = _FakeClient()
|
||||||
channel._chat_type_cache["group123"] = "group"
|
channel._chat_type_cache["group123"] = "group"
|
||||||
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
|||||||
|
|
||||||
assert len(channel._client.api.group_calls) == 1
|
assert len(channel._client.api.group_calls) == 1
|
||||||
call = channel._client.api.group_calls[0]
|
call = channel._client.api.group_calls[0]
|
||||||
assert call["group_openid"] == "group123"
|
assert call == {
|
||||||
assert call["msg_id"] == "msg1"
|
"group_openid": "group123",
|
||||||
assert call["msg_seq"] == 2
|
"msg_type": 0,
|
||||||
|
"content": "hello",
|
||||||
|
"msg_id": "msg1",
|
||||||
|
"msg_seq": 2,
|
||||||
|
}
|
||||||
assert not channel._client.api.c2c_calls
|
assert not channel._client.api.c2c_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "msg1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(channel._client.api.c2c_calls) == 1
|
||||||
|
call = channel._client.api.c2c_calls[0]
|
||||||
|
assert call == {
|
||||||
|
"openid": "user123",
|
||||||
|
"msg_type": 0,
|
||||||
|
"content": "hello",
|
||||||
|
"msg_id": "msg1",
|
||||||
|
"msg_seq": 2,
|
||||||
|
}
|
||||||
|
assert not channel._client.api.group_calls
|
||||||
|
|||||||
Reference in New Issue
Block a user