Merge branch 'main' into pr-1920

This commit is contained in:
Xubin Ren
2026-03-13 04:20:14 +00:00
14 changed files with 545 additions and 55 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Test Suite
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libolm-dev build-essential
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
- name: Run tests
run: python -m pytest tests/ -v

3
.gitignore vendored
View File

@@ -1,5 +1,6 @@
.worktrees/
.assets
.docs
.env
*.pyc
dist/
@@ -7,7 +8,7 @@ build/
docs/
*.egg-info/
*.egg
*.pyc
*.pycs
*.pyo
*.pyd
*.pyw

View File

@@ -758,15 +758,17 @@ Config file: `~/.nanobot/config.json`
> [!TIP]
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
> - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key |
|----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
@@ -776,7 +778,6 @@ Config file: `~/.nanobot/config.json`
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |

View File

@@ -139,7 +139,7 @@ class AgentLoop:
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True
except Exception as e:
except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack:
try:
@@ -292,7 +292,9 @@ class AgentLoop:
async def _do_restart():
await asyncio.sleep(1)
os.execv(sys.executable, [sys.executable] + sys.argv)
# Use -m nanobot instead of sys.argv[0] for Windows compatibility
# (sys.argv[0] may be just "nanobot" without full path on Windows)
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
asyncio.create_task(_do_restart())

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import json
import weakref
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
@@ -57,13 +58,30 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None
_TOOL_CHOICE_ERROR_MARKERS = (
"tool_choice",
"toolchoice",
"does not support",
'should be ["none", "auto"]',
)
def _is_tool_choice_unsupported(content: str | None) -> bool:
"""Detect provider errors caused by forced tool_choice being unsupported."""
text = (content or "").lower()
return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS)
class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
def __init__(self, workspace: Path):
self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "HISTORY.md"
self._consecutive_failures = 0
def read_long_term(self) -> str:
if self.memory_file.exists():
@@ -112,38 +130,93 @@ class MemoryStore:
## Conversation to Process
{self._format_messages(messages)}"""
try:
response = await provider.chat_with_retry(
messages=[
chat_messages = [
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
],
]
try:
forced = {"type": "function", "function": {"name": "save_memory"}}
response = await provider.chat_with_retry(
messages=chat_messages,
tools=_SAVE_MEMORY_TOOL,
model=model,
tool_choice="required",
tool_choice=forced,
)
if response.finish_reason == "error" and _is_tool_choice_unsupported(
response.content
):
logger.warning("Forced tool_choice unsupported, retrying with auto")
response = await provider.chat_with_retry(
messages=chat_messages,
tools=_SAVE_MEMORY_TOOL,
model=model,
tool_choice="auto",
)
if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return False
logger.warning(
"Memory consolidation: LLM did not call save_memory "
"(finish_reason={}, content_len={}, content_preview={})",
response.finish_reason,
len(response.content or ""),
(response.content or "")[:200],
)
return self._fail_or_raw_archive(messages)
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments")
return False
return self._fail_or_raw_archive(messages)
if entry := args.get("history_entry"):
self.append_history(_ensure_text(entry))
if update := args.get("memory_update"):
if "history_entry" not in args or "memory_update" not in args:
logger.warning("Memory consolidation: save_memory payload missing required fields")
return self._fail_or_raw_archive(messages)
entry = args["history_entry"]
update = args["memory_update"]
if entry is None or update is None:
logger.warning("Memory consolidation: save_memory payload contains null required fields")
return self._fail_or_raw_archive(messages)
entry = _ensure_text(entry).strip()
if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization")
return self._fail_or_raw_archive(messages)
self.append_history(entry)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
self._consecutive_failures = 0
logger.info("Memory consolidation done for {} messages", len(messages))
return True
except Exception:
logger.exception("Memory consolidation failed")
return self._fail_or_raw_archive(messages)
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
"""Increment failure count; after threshold, raw-archive messages and return True."""
self._consecutive_failures += 1
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
return False
self._raw_archive(messages)
self._consecutive_failures = 0
return True
def _raw_archive(self, messages: list[dict]) -> None:
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
self.append_history(
f"[{ts}] [RAW] {len(messages)} messages\n"
f"{self._format_messages(messages)}"
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
class MemoryConsolidator:

View File

@@ -149,13 +149,22 @@ class MatrixChannel(BaseChannel):
name = "matrix"
display_name = "Matrix"
def __init__(self, config: Any, bus: MessageBus):
def __init__(
self,
config: Any,
bus: MessageBus,
*,
restrict_to_workspace: bool = False,
workspace: str | Path | None = None,
):
super().__init__(config, bus)
self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._restrict_to_workspace = False
self._workspace: Path | None = None
self._restrict_to_workspace = bool(restrict_to_workspace)
self._workspace = (
Path(workspace).expanduser().resolve(strict=False) if workspace is not None else None
)
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False

View File

@@ -114,16 +114,16 @@ class QQChannel(BaseChannel):
if msg_type == "group":
await self._client.api.post_group_message(
group_openid=msg.chat_id,
msg_type=2,
markdown={"content": msg.content},
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq,
)
else:
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=2,
markdown={"content": msg.content},
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq,
)

View File

@@ -19,10 +19,12 @@ if sys.platform == "win32":
pass
import typer
from prompt_toolkit import print_formatted_text
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
@@ -111,8 +113,25 @@ def _init_prompt_session() -> None:
)
def _make_console() -> Console:
return Console(file=sys.stdout)
def _render_interactive_ansi(render_fn) -> str:
"""Render Rich output to ANSI so prompt_toolkit can print it safely."""
ansi_console = Console(
force_terminal=True,
color_system=console.color_system or "standard",
width=console.width,
)
with ansi_console.capture() as capture:
render_fn(ansi_console)
return capture.get()
def _print_agent_response(response: str, render_markdown: bool) -> None:
"""Render assistant response with consistent terminal styling."""
console = _make_console()
content = response or ""
body = Markdown(content) if render_markdown else Text(content)
console.print()
@@ -121,6 +140,34 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
console.print()
async def _print_interactive_line(text: str) -> None:
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
def _write() -> None:
ansi = _render_interactive_ansi(
lambda c: c.print(f" [dim]↳ {text}[/dim]")
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
def _write() -> None:
content = response or ""
ansi = _render_interactive_ansi(
lambda c: (
c.print(),
c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
c.print(Markdown(content) if render_markdown else Text(content)),
c.print(),
)
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
def _is_exit_command(command: str) -> bool:
"""Return True when input should end interactive chat."""
return command.lower() in EXIT_COMMANDS
@@ -610,14 +657,15 @@ def agent(
elif ch and not is_tool_hint and not ch.send_progress:
pass
else:
console.print(f" [dim]↳ {msg.content}[/dim]")
await _print_interactive_line(msg.content)
elif not turn_done.is_set():
if msg.content:
turn_response.append(msg.content)
turn_done.set()
elif msg.content:
console.print()
_print_agent_response(msg.content, render_markdown=markdown)
await _print_interactive_response(msg.content, render_markdown=markdown)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:

View File

@@ -276,15 +276,18 @@ class ProvidersConfig(Base):
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -398,12 +401,21 @@ class Config(BaseSettings):
# Fallback: configured local providers can route models without
# provider-specific keywords (for example plain "llama3.2" on Ollama).
# Prefer providers whose detect_by_base_keyword matches the configured api_base
# (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
local_fallback: tuple[ProviderConfig, str] | None = None
for spec in PROVIDERS:
if not spec.is_local:
continue
p = getattr(self.providers, spec.name, None)
if p and p.api_base:
if not (p and p.api_base):
continue
if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
return p, spec.name
if local_fallback is None:
local_fallback = (p, spec.name)
if local_fallback:
return local_fallback
# Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection

View File

@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# VolcEngine (火山引擎): OpenAI-compatible gateway
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
ProviderSpec(
name="volcengine_coding_plan",
keywords=("volcengine-plan",),
env_key="OPENAI_API_KEY",
display_name="VolcEngine Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus: VolcEngine international, pay-per-use models
ProviderSpec(
name="byteplus",
keywords=("byteplus",),
env_key="OPENAI_API_KEY",
display_name="BytePlus",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="bytepluses",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus Coding Plan: same key as byteplus
ProviderSpec(
name="byteplus_coding_plan",
keywords=("byteplus-plan",),
env_key="OPENAI_API_KEY",
display_name="BytePlus Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(

View File

@@ -75,13 +75,6 @@ build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["nanobot"]
[tool.hatch.build.targets.wheel.sources]
"nanobot" = "nanobot"
# Include non-Python files in skills and templates
[tool.hatch.build]
include = [
"nanobot/**/*.py",
@@ -90,6 +83,15 @@ include = [
"nanobot/skills/**/*.sh",
]
[tool.hatch.build.targets.wheel]
packages = ["nanobot"]
[tool.hatch.build.targets.wheel.sources]
"nanobot" = "nanobot"
[tool.hatch.build.targets.wheel.force-include]
"bridge" = "nanobot/bridge"
[tool.hatch.build.targets.sdist]
include = [
"nanobot/",
@@ -98,9 +100,6 @@ include = [
"LICENSE",
]
[tool.hatch.build.targets.wheel.force-include]
"bridge" = "nanobot/bridge"
[tool.ruff]
line-length = 100
target-version = "py311"

View File

@@ -1,3 +1,4 @@
import re
import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -11,6 +12,12 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
runner = CliRunner()
@@ -143,6 +150,35 @@ def test_config_auto_detects_ollama_from_local_api_base():
assert config.get_api_base() == "http://localhost:11434"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -199,10 +235,11 @@ def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0
assert "--workspace" in result.stdout
assert "-w" in result.stdout
assert "--config" in result.stdout
assert "-c" in result.stdout
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):

View File

@@ -112,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse(
content=None,
tool_calls=[
@@ -170,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse(
content=None,
tool_calls=[
@@ -242,6 +240,94 @@ class TestMemoryConsolidationTypeHandling:
assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path)
@@ -288,3 +374,105 @@ class TestMemoryConsolidationTypeHandling:
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error calling LLM: litellm.BadRequestError: "
"The tool_choice parameter does not support being set to required or object",
finish_reason="error",
tool_calls=[],
)
ok_resp = _make_tool_response(
history_entry="[2026-01-01] Fallback worked.",
memory_update="# Memory\nFallback OK.",
)
call_log: list[dict] = []
async def _tracking_chat(**kwargs):
call_log.append(kwargs)
return error_resp if len(call_log) == 1 else ok_resp
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert len(call_log) == 2
assert isinstance(call_log[0]["tool_choice"], dict)
assert call_log[1]["tool_choice"] == "auto"
assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
"""Forced rejected, auto retry also produces no tool call -> return False."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error: tool_choice must be none or auto",
finish_reason="error",
tool_calls=[],
)
no_tool_resp = LLMResponse(
content="Here is a summary.",
finish_reason="stop",
tool_calls=[],
)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1

View File

@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert call == {
"group_openid": "group123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls