Merge remote-tracking branch 'origin/main' into pr-1874

This commit is contained in:
Re-bin
2026-03-11 15:26:39 +00:00
35 changed files with 1562 additions and 707 deletions

1
.gitignore vendored
View File

@@ -20,4 +20,5 @@ __pycache__/
poetry.lock poetry.lock
.pytest_cache/ .pytest_cache/
botpy.log botpy.log
nano.*.save

View File

@@ -78,6 +78,25 @@
<img src="nanobot_arch.png" alt="nanobot architecture" width="800"> <img src="nanobot_arch.png" alt="nanobot architecture" width="800">
</p> </p>
## Table of Contents
- [News](#-news)
- [Key Features](#key-features-of-nanobot)
- [Architecture](#-architecture)
- [Features](#-features)
- [Install](#-install)
- [Quick Start](#-quick-start)
- [Chat Apps](#-chat-apps)
- [Agent Social Network](#-agent-social-network)
- [Configuration](#-configuration)
- [Multiple Instances](#-multiple-instances)
- [CLI Reference](#-cli-reference)
- [Docker](#-docker)
- [Linux Service](#-linux-service)
- [Project Structure](#-project-structure)
- [Contribute & Roadmap](#-contribute--roadmap)
- [Star History](#-star-history)
## ✨ Features ## ✨ Features
<table align="center"> <table align="center">
@@ -208,6 +227,7 @@ Connect nanobot to your favorite chat platform.
| **Slack** | Bot token + App-Level token | | **Slack** | Bot token + App-Level token |
| **Email** | IMAP/SMTP credentials | | **Email** | IMAP/SMTP credentials |
| **QQ** | App ID + App Secret | | **QQ** | App ID + App Secret |
| **Wecom** | Bot ID + Bot Secret |
<details> <details>
<summary><b>Telegram</b> (Recommended)</summary> <summary><b>Telegram</b> (Recommended)</summary>
@@ -677,6 +697,46 @@ nanobot gateway
</details> </details>
<details>
<summary><b>Wecom (企业微信)</b></summary>
> Here we use [wecom-aibot-sdk-python](https://github.com/chengyongru/wecom_aibot_sdk) (community Python version of the official [@wecom/aibot-node-sdk](https://www.npmjs.com/package/@wecom/aibot-node-sdk)).
>
> Uses **WebSocket** long connection — no public IP required.
**1. Install the optional dependency**
```bash
pip install nanobot-ai[wecom]
```
**2. Create a WeCom AI Bot**
Go to the WeCom admin console → Intelligent Robot → Create Robot → select **API mode** with **long connection**. Copy the Bot ID and Secret.
**3. Configure**
```json
{
"channels": {
"wecom": {
"enabled": true,
"botId": "your_bot_id",
"secret": "your_bot_secret",
"allowFrom": ["your_id"]
}
}
}
```
**4. Run**
```bash
nanobot gateway
```
</details>
## 🌐 Agent Social Network ## 🌐 Agent Social Network
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!** 🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
@@ -718,6 +778,7 @@ Config file: `~/.nanobot/config.json`
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
| `ollama` | LLM (local, Ollama) | — |
| `vllm` | LLM (local, any OpenAI-compatible server) | — | | `vllm` | LLM (local, any OpenAI-compatible server) | — |
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` | | `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` | | `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
@@ -783,6 +844,37 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
</details> </details>
<details>
<summary><b>Ollama (local)</b></summary>
Run a local model with Ollama, then add to config:
**1. Start Ollama** (example):
```bash
ollama run llama3.2
```
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
```json
{
"providers": {
"ollama": {
"apiBase": "http://localhost:11434"
}
},
"agents": {
"defaults": {
"provider": "ollama",
"model": "llama3.2"
}
}
}
```
> `provider: "auto"` also works when `providers.ollama.apiBase` is configured, but setting `"provider": "ollama"` is the clearest option.
</details>
<details> <details>
<summary><b>vLLM (local / OpenAI-compatible)</b></summary> <summary><b>vLLM (local / OpenAI-compatible)</b></summary>

View File

@@ -10,7 +10,7 @@ from typing import Any
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import detect_image_mime from nanobot.utils.helpers import build_assistant_message, detect_image_mime
class ContextBuilder: class ContextBuilder:
@@ -182,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send
thinking_blocks: list[dict] | None = None, thinking_blocks: list[dict] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Add an assistant message to the message list.""" """Add an assistant message to the message list."""
msg: dict[str, Any] = {"role": "assistant", "content": content} messages.append(build_assistant_message(
if tool_calls: content,
msg["tool_calls"] = tool_calls tool_calls=tool_calls,
if reasoning_content is not None: reasoning_content=reasoning_content,
msg["reasoning_content"] = reasoning_content thinking_blocks=thinking_blocks,
if thinking_blocks: ))
msg["thinking_blocks"] = thinking_blocks
messages.append(msg)
return messages return messages

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import re import re
import weakref
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -13,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
@@ -53,10 +52,7 @@ class AgentLoop:
workspace: Path, workspace: Path,
model: str | None = None, model: str | None = None,
max_iterations: int = 40, max_iterations: int = 40,
temperature: float = 0.1, context_window_tokens: int = 65_536,
max_tokens: int = 4096,
memory_window: int = 100,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None, exec_config: ExecToolConfig | None = None,
@@ -73,10 +69,7 @@ class AgentLoop:
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.temperature = temperature self.context_window_tokens = context_window_tokens
self.max_tokens = max_tokens
self.memory_window = memory_window
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@@ -91,9 +84,6 @@ class AgentLoop:
workspace=workspace, workspace=workspace,
bus=bus, bus=bus,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key, brave_api_key=brave_api_key,
web_proxy=web_proxy, web_proxy=web_proxy,
exec_config=self.exec_config, exec_config=self.exec_config,
@@ -105,11 +95,17 @@ class AgentLoop:
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._processing_lock = asyncio.Lock() self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator(
workspace=workspace,
provider=provider,
model=self.model,
sessions=self.sessions,
context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions,
)
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
@@ -182,7 +178,7 @@ class AgentLoop:
initial_messages: list[dict], initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None, on_progress: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict]]: ) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop. Returns (final_content, tools_used, messages).""" """Run the agent iteration loop."""
messages = initial_messages messages = initial_messages
iteration = 0 iteration = 0
final_content = None final_content = None
@@ -191,13 +187,12 @@ class AgentLoop:
while iteration < self.max_iterations: while iteration < self.max_iterations:
iteration += 1 iteration += 1
tool_defs = self.tools.get_definitions()
response = await self.provider.chat_with_retry( response = await self.provider.chat_with_retry(
messages=messages, messages=messages,
tools=self.tools.get_definitions(), tools=tool_defs,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
@@ -334,8 +329,9 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=self.memory_window) history = session.get_history(max_messages=0)
messages = self.context.build_messages( messages = self.context.build_messages(
history=history, history=history,
current_message=msg.content, channel=channel, chat_id=chat_id, current_message=msg.content, channel=channel, chat_id=chat_id,
@@ -343,6 +339,7 @@ class AgentLoop:
final_content, _, all_msgs = await self._run_agent_loop(messages) final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.") content=final_content or "Background task completed.")
@@ -355,27 +352,20 @@ class AgentLoop:
# Slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
self._consolidating.add(session.key)
try: try:
async with lock: if not await self.memory_consolidator.archive_unconsolidated(session):
snapshot = session.messages[session.last_consolidated:] return OutboundMessage(
if snapshot: channel=msg.channel,
temp = Session(key=session.key) chat_id=msg.chat_id,
temp.messages = list(snapshot) content="Memory archival failed, session not cleared. Please try again.",
if not await self._consolidate_memory(temp, archive_all=True): )
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
except Exception: except Exception:
logger.exception("/new archival failed for {}", session.key) logger.exception("/new archival failed for {}", session.key)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel,
chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.", content="Memory archival failed, session not cleared. Please try again.",
) )
finally:
self._consolidating.discard(session.key)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
@@ -386,30 +376,14 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
unconsolidated = len(session.messages) - session.last_consolidated await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
self._consolidating.add(session.key)
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
async def _consolidate_and_unlock():
try:
async with lock:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)
_task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(_task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool): if isinstance(message_tool, MessageTool):
message_tool.start_turn() message_tool.start_turn()
history = session.get_history(max_messages=self.memory_window) history = session.get_history(max_messages=0)
initial_messages = self.context.build_messages( initial_messages = self.context.build_messages(
history=history, history=history,
current_message=msg.content, current_message=msg.content,
@@ -434,6 +408,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None return None
@@ -480,13 +455,6 @@ class AgentLoop:
session.messages.append(entry) session.messages.append(entry)
session.updated_at = datetime.now() session.updated_at = datetime.now()
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
return await MemoryStore(self.workspace).consolidate(
session, self.provider, self.model,
archive_all=archive_all, memory_window=self.memory_window,
)
async def process_direct( async def process_direct(
self, self,
content: str, content: str,

View File

@@ -2,17 +2,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import weakref
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Callable
from loguru import logger from loguru import logger
from nanobot.utils.helpers import ensure_dir from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
if TYPE_CHECKING: if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session from nanobot.session.manager import Session, SessionManager
_SAVE_MEMORY_TOOL = [ _SAVE_MEMORY_TOOL = [
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
"properties": { "properties": {
"history_entry": { "history_entry": {
"type": "string", "type": "string",
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. " "description": "A paragraph summarizing key events/decisions/topics. "
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
}, },
"memory_update": { "memory_update": {
@@ -42,6 +44,19 @@ _SAVE_MEMORY_TOOL = [
] ]
def _ensure_text(value: Any) -> str:
"""Normalize tool-call payload values to text for file storage."""
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
"""Normalize provider tool-call arguments to the expected dict shape."""
if isinstance(args, str):
args = json.loads(args)
if isinstance(args, list):
return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -66,40 +81,27 @@ class MemoryStore:
long_term = self.read_long_term() long_term = self.read_long_term()
return f"## Long-term Memory\n{long_term}" if long_term else "" return f"## Long-term Memory\n{long_term}" if long_term else ""
@staticmethod
def _format_messages(messages: list[dict]) -> str:
lines = []
for message in messages:
if not message.get("content"):
continue
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
lines.append(
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
)
return "\n".join(lines)
async def consolidate( async def consolidate(
self, self,
session: Session, messages: list[dict],
provider: LLMProvider, provider: LLMProvider,
model: str, model: str,
*,
archive_all: bool = False,
memory_window: int = 50,
) -> bool: ) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. """Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
if not messages:
Returns True on success (including no-op), False on failure. return True
"""
if archive_all:
old_messages = session.messages
keep_count = 0
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
else:
keep_count = memory_window // 2
if len(session.messages) <= keep_count:
return True
if len(session.messages) - session.last_consolidated <= 0:
return True
old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages:
return True
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
lines = []
for m in old_messages:
if not m.get("content"):
continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
current_memory = self.read_long_term() current_memory = self.read_long_term()
prompt = f"""Process this conversation and call the save_memory tool with your consolidation. prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
@@ -108,7 +110,7 @@ class MemoryStore:
{current_memory or "(empty)"} {current_memory or "(empty)"}
## Conversation to Process ## Conversation to Process
{chr(10).join(lines)}""" {self._format_messages(messages)}"""
try: try:
response = await provider.chat_with_retry( response = await provider.chat_with_retry(
@@ -124,34 +126,158 @@ class MemoryStore:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping") logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return False return False
args = response.tool_calls[0].arguments args = _normalize_save_memory_args(response.tool_calls[0].arguments)
# Some providers return arguments as a JSON string instead of dict if args is None:
if isinstance(args, str): logger.warning("Memory consolidation: unexpected save_memory arguments")
args = json.loads(args)
# Some providers return arguments as a list (handle edge case)
if isinstance(args, list):
if args and isinstance(args[0], dict):
args = args[0]
else:
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
return False
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False return False
if entry := args.get("history_entry"): if entry := args.get("history_entry"):
if not isinstance(entry, str): self.append_history(_ensure_text(entry))
entry = json.dumps(entry, ensure_ascii=False)
self.append_history(entry)
if update := args.get("memory_update"): if update := args.get("memory_update"):
if not isinstance(update, str): update = _ensure_text(update)
update = json.dumps(update, ensure_ascii=False)
if update != current_memory: if update != current_memory:
self.write_long_term(update) self.write_long_term(update)
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count logger.info("Memory consolidation done for {} messages", len(messages))
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return False return False
class MemoryConsolidator:
"""Owns consolidation policy, locking, and session offset updates."""
_MAX_CONSOLIDATION_ROUNDS = 5
def __init__(
self,
workspace: Path,
provider: LLMProvider,
model: str,
sessions: SessionManager,
context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]],
):
self.store = MemoryStore(workspace)
self.provider = provider
self.model = model
self.sessions = sessions
self.context_window_tokens = context_window_tokens
self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock())
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
"""Archive a selected message chunk into persistent memory."""
return await self.store.consolidate(messages, self.provider, self.model)
def pick_consolidation_boundary(
self,
session: Session,
tokens_to_remove: int,
) -> tuple[int, int] | None:
"""Pick a user-turn boundary that removes enough old prompt tokens."""
start = session.last_consolidated
if start >= len(session.messages) or tokens_to_remove <= 0:
return None
removed_tokens = 0
last_boundary: tuple[int, int] | None = None
for idx in range(start, len(session.messages)):
message = session.messages[idx]
if idx > start and message.get("role") == "user":
last_boundary = (idx, removed_tokens)
if removed_tokens >= tokens_to_remove:
return last_boundary
removed_tokens += estimate_message_tokens(message)
return last_boundary
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self._build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
)
return estimate_prompt_tokens_chain(
self.provider,
self.model,
probe_messages,
self._get_tool_definitions(),
)
async def archive_unconsolidated(self, session: Session) -> bool:
"""Archive the full unconsolidated tail for /new-style session rollover."""
lock = self.get_lock(session.key)
async with lock:
snapshot = session.messages[session.last_consolidated:]
if not snapshot:
return True
return await self.consolidate_messages(snapshot)
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within half the context window."""
if not session.messages or self.context_window_tokens <= 0:
return
lock = self.get_lock(session.key)
async with lock:
target = self.context_window_tokens // 2
estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0:
return
if estimated < self.context_window_tokens:
logger.debug(
"Token consolidation idle {}: {}/{} via {}",
session.key,
estimated,
self.context_window_tokens,
source,
)
return
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
if estimated <= target:
return
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
if boundary is None:
logger.debug(
"Token consolidation: no safe boundary for {} (round {})",
session.key,
round_num,
)
return
end_idx = boundary[0]
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return
logger.info(
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
round_num,
session.key,
estimated,
self.context_window_tokens,
source,
len(chunk),
)
if not await self.consolidate_messages(chunk):
return
session.last_consolidated = end_idx
self.sessions.save(session)
estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0:
return

View File

@@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.utils.helpers import build_assistant_message
class SubagentManager: class SubagentManager:
@@ -27,9 +28,6 @@ class SubagentManager:
workspace: Path, workspace: Path,
bus: MessageBus, bus: MessageBus,
model: str | None = None, model: str | None = None,
temperature: float = 0.7,
max_tokens: int = 4096,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None, exec_config: "ExecToolConfig | None" = None,
@@ -40,9 +38,6 @@ class SubagentManager:
self.workspace = workspace self.workspace = workspace
self.bus = bus self.bus = bus
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.temperature = temperature
self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@@ -127,22 +122,19 @@ class SubagentManager:
messages=messages, messages=messages,
tools=tools.get_definitions(), tools=tools.get_definitions(),
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
# Add assistant message with tool calls
tool_call_dicts = [ tool_call_dicts = [
tc.to_openai_tool_call() tc.to_openai_tool_call()
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages.append({ messages.append(build_assistant_message(
"role": "assistant", response.content or "",
"content": response.content or "", tool_calls=tool_call_dicts,
"tool_calls": tool_call_dicts, reasoning_content=response.reasoning_content,
}) thinking_blocks=response.thinking_blocks,
))
# Execute tools # Execute tools
for tool_call in response.tool_calls: for tool_call in response.tool_calls:

View File

@@ -1,6 +1,9 @@
"""Base channel interface for chat platforms.""" """Base channel interface for chat platforms."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -18,6 +21,8 @@ class BaseChannel(ABC):
""" """
name: str = "base" name: str = "base"
display_name: str = "Base"
transcription_api_key: str = ""
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """
@@ -31,6 +36,19 @@ class BaseChannel(ABC):
self.bus = bus self.bus = bus
self._running = False self._running = False
async def transcribe_audio(self, file_path: str | Path) -> str:
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
if not self.transcription_api_key:
return ""
try:
from nanobot.providers.transcription import GroqTranscriptionProvider
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
return await provider.transcribe(file_path)
except Exception as e:
logger.warning("{}: audio transcription failed: {}", self.name, e)
return ""
@abstractmethod @abstractmethod
async def start(self) -> None: async def start(self) -> None:
""" """

View File

@@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler):
content = "" content = ""
if chatbot_msg.text: if chatbot_msg.text:
content = chatbot_msg.text.content.strip() content = chatbot_msg.text.content.strip()
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
content = chatbot_msg.extensions["content"]["recognition"].strip()
if not content: if not content:
content = message.data.get("text", {}).get("content", "").strip() content = message.data.get("text", {}).get("content", "").strip()
@@ -112,6 +114,7 @@ class DingTalkChannel(BaseChannel):
""" """
name = "dingtalk" name = "dingtalk"
display_name = "DingTalk"
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}

View File

@@ -25,6 +25,7 @@ class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket.""" """Discord channel using Gateway websocket."""
name = "discord" name = "discord"
display_name = "Discord"
def __init__(self, config: DiscordConfig, bus: MessageBus): def __init__(self, config: DiscordConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)

View File

@@ -35,6 +35,7 @@ class EmailChannel(BaseChannel):
""" """
name = "email" name = "email"
display_name = "Email"
_IMAP_MONTHS = ( _IMAP_MONTHS = (
"Jan", "Jan",
"Feb", "Feb",

View File

@@ -244,11 +244,11 @@ class FeishuChannel(BaseChannel):
""" """
name = "feishu" name = "feishu"
display_name = "Feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""): def __init__(self, config: FeishuConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: FeishuConfig = config self.config: FeishuConfig = config
self.groq_api_key = groq_api_key
self._client: Any = None self._client: Any = None
self._ws_client: Any = None self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None self._ws_thread: threading.Thread | None = None
@@ -928,16 +928,10 @@ class FeishuChannel(BaseChannel):
if file_path: if file_path:
media_paths.append(file_path) media_paths.append(file_path)
# Transcribe audio using Groq Whisper if msg_type == "audio" and file_path:
if msg_type == "audio" and file_path and self.groq_api_key: transcription = await self.transcribe_audio(file_path)
try: if transcription:
from nanobot.providers.transcription import GroqTranscriptionProvider content_text = f"[transcription: {transcription}]"
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path)
if transcription:
content_text = f"[transcription: {transcription}]"
except Exception as e:
logger.warning("Failed to transcribe audio: {}", e)
content_parts.append(content_text) content_parts.append(content_text)

View File

@@ -7,7 +7,6 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config from nanobot.config.schema import Config
@@ -32,123 +31,23 @@ class ChannelManager:
self._init_channels() self._init_channels()
def _init_channels(self) -> None: def _init_channels(self) -> None:
"""Initialize channels based on config.""" """Initialize channels discovered via pkgutil scan."""
from nanobot.channels.registry import discover_channel_names, load_channel_class
# Telegram channel groq_key = self.config.providers.groq.api_key
if self.config.channels.telegram.enabled:
for modname in discover_channel_names():
section = getattr(self.config.channels, modname, None)
if not section or not getattr(section, "enabled", False):
continue
try: try:
from nanobot.channels.telegram import TelegramChannel cls = load_channel_class(modname)
self.channels["telegram"] = TelegramChannel( channel = cls(section, self.bus)
self.config.channels.telegram, channel.transcription_api_key = groq_key
self.bus, self.channels[modname] = channel
groq_api_key=self.config.providers.groq.api_key, logger.info("{} channel enabled", cls.display_name)
)
logger.info("Telegram channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Telegram channel not available: {}", e) logger.warning("{} channel not available: {}", modname, e)
# WhatsApp channel
if self.config.channels.whatsapp.enabled:
try:
from nanobot.channels.whatsapp import WhatsAppChannel
self.channels["whatsapp"] = WhatsAppChannel(
self.config.channels.whatsapp, self.bus
)
logger.info("WhatsApp channel enabled")
except ImportError as e:
logger.warning("WhatsApp channel not available: {}", e)
# Discord channel
if self.config.channels.discord.enabled:
try:
from nanobot.channels.discord import DiscordChannel
self.channels["discord"] = DiscordChannel(
self.config.channels.discord, self.bus
)
logger.info("Discord channel enabled")
except ImportError as e:
logger.warning("Discord channel not available: {}", e)
# Feishu channel
if self.config.channels.feishu.enabled:
try:
from nanobot.channels.feishu import FeishuChannel
self.channels["feishu"] = FeishuChannel(
self.config.channels.feishu, self.bus,
groq_api_key=self.config.providers.groq.api_key,
)
logger.info("Feishu channel enabled")
except ImportError as e:
logger.warning("Feishu channel not available: {}", e)
# Mochat channel
if self.config.channels.mochat.enabled:
try:
from nanobot.channels.mochat import MochatChannel
self.channels["mochat"] = MochatChannel(
self.config.channels.mochat, self.bus
)
logger.info("Mochat channel enabled")
except ImportError as e:
logger.warning("Mochat channel not available: {}", e)
# DingTalk channel
if self.config.channels.dingtalk.enabled:
try:
from nanobot.channels.dingtalk import DingTalkChannel
self.channels["dingtalk"] = DingTalkChannel(
self.config.channels.dingtalk, self.bus
)
logger.info("DingTalk channel enabled")
except ImportError as e:
logger.warning("DingTalk channel not available: {}", e)
# Email channel
if self.config.channels.email.enabled:
try:
from nanobot.channels.email import EmailChannel
self.channels["email"] = EmailChannel(
self.config.channels.email, self.bus
)
logger.info("Email channel enabled")
except ImportError as e:
logger.warning("Email channel not available: {}", e)
# Slack channel
if self.config.channels.slack.enabled:
try:
from nanobot.channels.slack import SlackChannel
self.channels["slack"] = SlackChannel(
self.config.channels.slack, self.bus
)
logger.info("Slack channel enabled")
except ImportError as e:
logger.warning("Slack channel not available: {}", e)
# QQ channel
if self.config.channels.qq.enabled:
try:
from nanobot.channels.qq import QQChannel
self.channels["qq"] = QQChannel(
self.config.channels.qq,
self.bus,
)
logger.info("QQ channel enabled")
except ImportError as e:
logger.warning("QQ channel not available: {}", e)
# Matrix channel
if self.config.channels.matrix.enabled:
try:
from nanobot.channels.matrix import MatrixChannel
self.channels["matrix"] = MatrixChannel(
self.config.channels.matrix,
self.bus,
)
logger.info("Matrix channel enabled")
except ImportError as e:
logger.warning("Matrix channel not available: {}", e)
self._validate_allow_from() self._validate_allow_from()

View File

@@ -37,6 +37,7 @@ except ImportError as e:
) from e ) from e
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_data_dir, get_media_dir from nanobot.config.paths import get_data_dir, get_media_dir
from nanobot.utils.helpers import safe_filename from nanobot.utils.helpers import safe_filename
@@ -146,15 +147,15 @@ class MatrixChannel(BaseChannel):
"""Matrix (Element) channel using long-polling sync.""" """Matrix (Element) channel using long-polling sync."""
name = "matrix" name = "matrix"
display_name = "Matrix"
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False, def __init__(self, config: Any, bus: MessageBus):
workspace: Path | None = None):
super().__init__(config, bus) super().__init__(config, bus)
self.client: AsyncClient | None = None self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tasks: dict[str, asyncio.Task] = {}
self._restrict_to_workspace = restrict_to_workspace self._restrict_to_workspace = False
self._workspace = workspace.expanduser().resolve() if workspace else None self._workspace: Path | None = None
self._server_upload_limit_bytes: int | None = None self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False self._server_upload_limit_checked = False
@@ -677,7 +678,14 @@ class MatrixChannel(BaseChannel):
parts: list[str] = [] parts: list[str] = []
if isinstance(body := getattr(event, "body", None), str) and body.strip(): if isinstance(body := getattr(event, "body", None), str) and body.strip():
parts.append(body.strip()) parts.append(body.strip())
if marker:
if attachment and attachment.get("type") == "audio":
transcription = await self.transcribe_audio(attachment["path"])
if transcription:
parts.append(f"[transcription: {transcription}]")
else:
parts.append(marker)
elif marker:
parts.append(marker) parts.append(marker)
await self._start_typing_keepalive(room.room_id) await self._start_typing_keepalive(room.room_id)

View File

@@ -216,6 +216,7 @@ class MochatChannel(BaseChannel):
"""Mochat channel using socket.io with fallback polling workers.""" """Mochat channel using socket.io with fallback polling workers."""
name = "mochat" name = "mochat"
display_name = "Mochat"
def __init__(self, config: MochatConfig, bus: MessageBus): def __init__(self, config: MochatConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)

View File

@@ -54,6 +54,7 @@ class QQChannel(BaseChannel):
"""QQ channel using botpy SDK with WebSocket connection.""" """QQ channel using botpy SDK with WebSocket connection."""
name = "qq" name = "qq"
display_name = "QQ"
def __init__(self, config: QQConfig, bus: MessageBus): def __init__(self, config: QQConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)

View File

@@ -0,0 +1,35 @@
"""Auto-discovery for channel modules — no hardcoded registry."""
from __future__ import annotations
import importlib
import pkgutil
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanobot.channels.base import BaseChannel
_INTERNAL = frozenset({"base", "manager", "registry"})
def discover_channel_names() -> list[str]:
"""Return all channel module names by scanning the package (zero imports)."""
import nanobot.channels as pkg
return [
name
for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
if name not in _INTERNAL and not ispkg
]
def load_channel_class(module_name: str) -> type[BaseChannel]:
"""Import *module_name* and return the first BaseChannel subclass found."""
from nanobot.channels.base import BaseChannel as _Base
mod = importlib.import_module(f"nanobot.channels.{module_name}")
for attr in dir(mod):
obj = getattr(mod, attr)
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
return obj
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")

View File

@@ -21,6 +21,7 @@ class SlackChannel(BaseChannel):
"""Slack channel using Socket Mode.""" """Slack channel using Socket Mode."""
name = "slack" name = "slack"
display_name = "Slack"
def __init__(self, config: SlackConfig, bus: MessageBus): def __init__(self, config: SlackConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)

View File

@@ -155,6 +155,7 @@ class TelegramChannel(BaseChannel):
""" """
name = "telegram" name = "telegram"
display_name = "Telegram"
# Commands registered with Telegram's command menu # Commands registered with Telegram's command menu
BOT_COMMANDS = [ BOT_COMMANDS = [
@@ -164,15 +165,9 @@ class TelegramChannel(BaseChannel):
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
def __init__( def __init__(self, config: TelegramConfig, bus: MessageBus):
self,
config: TelegramConfig,
bus: MessageBus,
groq_api_key: str = "",
):
super().__init__(config, bus) super().__init__(config, bus)
self.config: TelegramConfig = config self.config: TelegramConfig = config
self.groq_api_key = groq_api_key
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
@@ -615,11 +610,8 @@ class TelegramChannel(BaseChannel):
media_paths.append(str(file_path)) media_paths.append(str(file_path))
# Handle voice transcription if media_type in ("voice", "audio"):
if media_type == "voice" or media_type == "audio": transcription = await self.transcribe_audio(file_path)
from nanobot.providers.transcription import GroqTranscriptionProvider
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path)
if transcription: if transcription:
logger.info("Transcribed {}: {}...", media_type, transcription[:50]) logger.info("Transcribed {}: {}...", media_type, transcription[:50])
content_parts.append(f"[transcription: {transcription}]") content_parts.append(f"[transcription: {transcription}]")

353
nanobot/channels/wecom.py Normal file
View File

@@ -0,0 +1,353 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
import asyncio
import importlib.util
import os
from collections import OrderedDict
from typing import Any
from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import WecomConfig
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
# Message type display mapping
MSG_TYPE_MAP = {
"image": "[image]",
"voice": "[voice]",
"file": "[file]",
"mixed": "[mixed content]",
}
class WecomChannel(BaseChannel):
"""
WeCom (Enterprise WeChat) channel using WebSocket long connection.
Uses WebSocket to receive events - no public IP or webhook required.
Requires:
- Bot ID and Secret from WeCom AI Bot platform
"""
name = "wecom"
display_name = "WeCom"
def __init__(self, config: WecomConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: WecomConfig = config
self._client: Any = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
self._loop: asyncio.AbstractEventLoop | None = None
self._generate_req_id = None
# Store frame headers for each chat to enable replies
self._chat_frames: dict[str, Any] = {}
async def start(self) -> None:
"""Start the WeCom bot with WebSocket long connection."""
if not WECOM_AVAILABLE:
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
return
if not self.config.bot_id or not self.config.secret:
logger.error("WeCom bot_id and secret not configured")
return
from wecom_aibot_sdk import WSClient, generate_req_id
self._running = True
self._loop = asyncio.get_running_loop()
self._generate_req_id = generate_req_id
# Create WebSocket client
self._client = WSClient({
"bot_id": self.config.bot_id,
"secret": self.config.secret,
"reconnect_interval": 1000,
"max_reconnect_attempts": -1, # Infinite reconnect
"heartbeat_interval": 30000,
})
# Register event handlers
self._client.on("connected", self._on_connected)
self._client.on("authenticated", self._on_authenticated)
self._client.on("disconnected", self._on_disconnected)
self._client.on("error", self._on_error)
self._client.on("message.text", self._on_text_message)
self._client.on("message.image", self._on_image_message)
self._client.on("message.voice", self._on_voice_message)
self._client.on("message.file", self._on_file_message)
self._client.on("message.mixed", self._on_mixed_message)
self._client.on("event.enter_chat", self._on_enter_chat)
logger.info("WeCom bot starting with WebSocket long connection")
logger.info("No public IP required - using WebSocket to receive events")
# Connect
await self._client.connect_async()
# Keep running until stopped
while self._running:
await asyncio.sleep(1)
async def stop(self) -> None:
"""Stop the WeCom bot."""
self._running = False
if self._client:
await self._client.disconnect()
logger.info("WeCom bot stopped")
async def _on_connected(self, frame: Any) -> None:
"""Handle WebSocket connected event."""
logger.info("WeCom WebSocket connected")
async def _on_authenticated(self, frame: Any) -> None:
"""Handle authentication success event."""
logger.info("WeCom authenticated successfully")
async def _on_disconnected(self, frame: Any) -> None:
"""Handle WebSocket disconnected event."""
reason = frame.body if hasattr(frame, 'body') else str(frame)
logger.warning("WeCom WebSocket disconnected: {}", reason)
async def _on_error(self, frame: Any) -> None:
"""Handle error event."""
logger.error("WeCom error: {}", frame)
async def _on_text_message(self, frame: Any) -> None:
"""Handle text message."""
await self._process_message(frame, "text")
async def _on_image_message(self, frame: Any) -> None:
"""Handle image message."""
await self._process_message(frame, "image")
async def _on_voice_message(self, frame: Any) -> None:
"""Handle voice message."""
await self._process_message(frame, "voice")
async def _on_file_message(self, frame: Any) -> None:
"""Handle file message."""
await self._process_message(frame, "file")
async def _on_mixed_message(self, frame: Any) -> None:
"""Handle mixed content message."""
await self._process_message(frame, "mixed")
async def _on_enter_chat(self, frame: Any) -> None:
"""Handle enter_chat event (user opens chat with bot)."""
try:
# Extract body from WsFrame dataclass or dict
if hasattr(frame, 'body'):
body = frame.body or {}
elif isinstance(frame, dict):
body = frame.get("body", frame)
else:
body = {}
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
if chat_id and self.config.welcome_message:
await self._client.reply_welcome(frame, {
"msgtype": "text",
"text": {"content": self.config.welcome_message},
})
except Exception as e:
logger.error("Error handling enter_chat: {}", e)
async def _process_message(self, frame: Any, msg_type: str) -> None:
"""Process incoming message and forward to bus."""
try:
# Extract body from WsFrame dataclass or dict
if hasattr(frame, 'body'):
body = frame.body or {}
elif isinstance(frame, dict):
body = frame.get("body", frame)
else:
body = {}
# Ensure body is a dict
if not isinstance(body, dict):
logger.warning("Invalid body type: {}", type(body))
return
# Extract message info
msg_id = body.get("msgid", "")
if not msg_id:
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
# Deduplication check
if msg_id in self._processed_message_ids:
return
self._processed_message_ids[msg_id] = None
# Trim cache
while len(self._processed_message_ids) > 1000:
self._processed_message_ids.popitem(last=False)
# Extract sender info from "from" field (SDK format)
from_info = body.get("from", {})
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
# For single chat, chatid is the sender's userid
# For group chat, chatid is provided in body
chat_type = body.get("chattype", "single")
chat_id = body.get("chatid", sender_id)
content_parts = []
if msg_type == "text":
text = body.get("text", {}).get("content", "")
if text:
content_parts.append(text)
elif msg_type == "image":
image_info = body.get("image", {})
file_url = image_info.get("url", "")
aes_key = image_info.get("aeskey", "")
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "image")
if file_path:
filename = os.path.basename(file_path)
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
else:
content_parts.append("[image: download failed]")
else:
content_parts.append("[image: download failed]")
elif msg_type == "voice":
voice_info = body.get("voice", {})
# Voice message already contains transcribed content from WeCom
voice_content = voice_info.get("content", "")
if voice_content:
content_parts.append(f"[voice] {voice_content}")
else:
content_parts.append("[voice]")
elif msg_type == "file":
file_info = body.get("file", {})
file_url = file_info.get("url", "")
aes_key = file_info.get("aeskey", "")
file_name = file_info.get("name", "unknown")
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
else:
content_parts.append(f"[file: {file_name}: download failed]")
else:
content_parts.append(f"[file: {file_name}: download failed]")
elif msg_type == "mixed":
# Mixed content contains multiple message items
msg_items = body.get("mixed", {}).get("item", [])
for item in msg_items:
item_type = item.get("type", "")
if item_type == "text":
text = item.get("text", {}).get("content", "")
if text:
content_parts.append(text)
else:
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
else:
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
content = "\n".join(content_parts) if content_parts else ""
if not content:
return
# Store frame for this chat to enable replies
self._chat_frames[chat_id] = frame
# Forward to message bus
# Note: media paths are included in content for broader model compatibility
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=content,
media=None,
metadata={
"message_id": msg_id,
"msg_type": msg_type,
"chat_type": chat_type,
}
)
except Exception as e:
logger.error("Error processing WeCom message: {}", e)
async def _download_and_save_media(
self,
file_url: str,
aes_key: str,
media_type: str,
filename: str | None = None,
) -> str | None:
"""
Download and decrypt media from WeCom.
Returns:
file_path or None if download failed
"""
try:
data, fname = await self._client.download_file(file_url, aes_key)
if not data:
logger.warning("Failed to download media from WeCom")
return None
media_dir = get_media_dir("wecom")
if not filename:
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
filename = os.path.basename(filename)
file_path = media_dir / filename
file_path.write_bytes(data)
logger.debug("Downloaded {} to {}", media_type, file_path)
return str(file_path)
except Exception as e:
logger.error("Error downloading media: {}", e)
return None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WeCom."""
if not self._client:
logger.warning("WeCom client not initialized")
return
try:
content = msg.content.strip()
if not content:
return
# Get the stored frame for this chat
frame = self._chat_frames.get(msg.chat_id)
if not frame:
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
return
# Use streaming reply for better UX
stream_id = self._generate_req_id("stream")
# Send as streaming message with finish=True
await self._client.reply_stream(
frame,
stream_id,
content,
finish=True,
)
logger.debug("WeCom message sent to {}", msg.chat_id)
except Exception as e:
logger.error("Error sending WeCom message: {}", e)

View File

@@ -22,6 +22,7 @@ class WhatsAppChannel(BaseChannel):
""" """
name = "whatsapp" name = "whatsapp"
display_name = "WhatsApp"
def __init__(self, config: WhatsAppConfig, bus: MessageBus): def __init__(self, config: WhatsAppConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)

View File

@@ -191,6 +191,8 @@ def onboard():
save_config(Config()) save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}") console.print(f"[green]✓[/green] Created config at {config_path}")
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
# Create workspace # Create workspace
workspace = get_workspace_path() workspace = get_workspace_path()
@@ -213,6 +215,7 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
@@ -222,46 +225,50 @@ def _make_provider(config: Config):
# OpenAI Codex (OAuth) # OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"): if provider_name == "openai_codex" or model.startswith("openai-codex/"):
return OpenAICodexProvider(default_model=model) provider = OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
from nanobot.providers.custom_provider import CustomProvider elif provider_name == "custom":
if provider_name == "custom": from nanobot.providers.custom_provider import CustomProvider
return CustomProvider( provider = CustomProvider(
api_key=p.api_key if p else "no-key", api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1", api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model, default_model=model,
) )
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai": elif provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base: if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.") console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1) raise typer.Exit(1)
provider = AzureOpenAIProvider(
return AzureOpenAIProvider(
api_key=p.api_key, api_key=p.api_key,
api_base=p.api_base, api_base=p.api_base,
default_model=model, default_model=model,
) )
else:
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
provider = LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
)
from nanobot.providers.litellm_provider import LiteLLMProvider defaults = config.agents.defaults
from nanobot.providers.registry import find_by_name provider.generation = GenerationSettings(
spec = find_by_name(provider_name) temperature=defaults.temperature,
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth): max_tokens=defaults.max_tokens,
console.print("[red]Error: No API key configured.[/red]") reasoning_effort=defaults.reasoning_effort,
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
return LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
) )
return provider
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config: def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
@@ -283,6 +290,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
return loaded return loaded
def _print_deprecated_memory_window_notice(config: Config) -> None:
"""Warn when running with old memoryWindow-only config."""
if config.agents.defaults.should_warn_deprecated_memory_window:
console.print(
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
"`contextWindowTokens`. `memoryWindow` is ignored; run "
"[cyan]nanobot onboard[/cyan] to refresh your config template."
)
# ============================================================================ # ============================================================================
# Gateway / Server # Gateway / Server
# ============================================================================ # ============================================================================
@@ -310,6 +327,7 @@ def gateway(
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
config = _load_runtime_config(config, workspace) config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
port = port if port is not None else config.gateway.port port = port if port is not None else config.gateway.port
console.print(f"{__logo__} Starting nanobot gateway on port {port}...") console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
@@ -328,11 +346,8 @@ def gateway(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, context_window_tokens=config.agents.defaults.context_window_tokens,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
@@ -494,6 +509,7 @@ def agent(
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
config = _load_runtime_config(config, workspace) config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
@@ -513,11 +529,8 @@ def agent(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, context_window_tokens=config.agents.defaults.context_window_tokens,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
@@ -670,6 +683,7 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status") @channels_app.command("status")
def channels_status(): def channels_status():
"""Show channel status.""" """Show channel status."""
from nanobot.channels.registry import discover_channel_names, load_channel_class
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
config = load_config() config = load_config()
@@ -677,85 +691,19 @@ def channels_status():
table = Table(title="Channel Status") table = Table(title="Channel Status")
table.add_column("Channel", style="cyan") table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green") table.add_column("Enabled", style="green")
table.add_column("Configuration", style="yellow")
# WhatsApp for modname in sorted(discover_channel_names()):
wa = config.channels.whatsapp section = getattr(config.channels, modname, None)
table.add_row( enabled = section and getattr(section, "enabled", False)
"WhatsApp", try:
"" if wa.enabled else "", cls = load_channel_class(modname)
wa.bridge_url display = cls.display_name
) except ImportError:
display = modname.title()
dc = config.channels.discord table.add_row(
table.add_row( display,
"Discord", "[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
"" if dc.enabled else "", )
dc.gateway_url
)
# Feishu
fs = config.channels.feishu
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
table.add_row(
"Feishu",
"" if fs.enabled else "",
fs_config
)
# Mochat
mc = config.channels.mochat
mc_base = mc.base_url or "[dim]not configured[/dim]"
table.add_row(
"Mochat",
"" if mc.enabled else "",
mc_base
)
# Telegram
tg = config.channels.telegram
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
table.add_row(
"Telegram",
"" if tg.enabled else "",
tg_config
)
# Slack
slack = config.channels.slack
slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
table.add_row(
"Slack",
"" if slack.enabled else "",
slack_config
)
# DingTalk
dt = config.channels.dingtalk
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
table.add_row(
"DingTalk",
"" if dt.enabled else "",
dt_config
)
# QQ
qq = config.channels.qq
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
table.add_row(
"QQ",
"" if qq.enabled else "",
qq_config
)
# Email
em = config.channels.email
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
table.add_row(
"Email",
"" if em.enabled else "",
em_config
)
console.print(table) console.print(table)

View File

@@ -200,6 +200,14 @@ class QQConfig(Base):
) # Allowed user openids (empty = public access) ) # Allowed user openids (empty = public access)
class WecomConfig(Base):
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
enabled: bool = False
bot_id: str = "" # Bot ID from WeCom AI Bot platform
secret: str = "" # Bot Secret from WeCom AI Bot platform
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
welcome_message: str = "" # Welcome message for enter_chat event
class ChannelsConfig(Base): class ChannelsConfig(Base):
@@ -217,6 +225,7 @@ class ChannelsConfig(Base):
slack: SlackConfig = Field(default_factory=SlackConfig) slack: SlackConfig = Field(default_factory=SlackConfig)
qq: QQConfig = Field(default_factory=QQConfig) qq: QQConfig = Field(default_factory=QQConfig)
matrix: MatrixConfig = Field(default_factory=MatrixConfig) matrix: MatrixConfig = Field(default_factory=MatrixConfig)
wecom: WecomConfig = Field(default_factory=WecomConfig)
class AgentDefaults(Base): class AgentDefaults(Base):
@@ -228,11 +237,18 @@ class AgentDefaults(Base):
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
) )
max_tokens: int = 8192 max_tokens: int = 8192
context_window_tokens: int = 65_536
temperature: float = 0.1 temperature: float = 0.1
max_tool_iterations: int = 40 max_tool_iterations: int = 40
memory_window: int = 100 # Deprecated compatibility field: accepted from old configs but ignored at runtime.
memory_window: int | None = Field(default=None, exclude=True)
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
@property
def should_warn_deprecated_memory_window(self) -> bool:
"""Return True when old memoryWindow is present without contextWindowTokens."""
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
class AgentsConfig(Base): class AgentsConfig(Base):
"""Agent configuration.""" """Agent configuration."""
@@ -265,6 +281,7 @@ class ProvidersConfig(Base):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
@@ -368,16 +385,25 @@ class Config(BaseSettings):
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and model_prefix and normalized_prefix == spec.name: if p and model_prefix and normalized_prefix == spec.name:
if spec.is_oauth or p.api_key: if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name return p, spec.name
# Match by keyword (order follows PROVIDERS registry) # Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and any(_kw_matches(kw) for kw in spec.keywords): if p and any(_kw_matches(kw) for kw in spec.keywords):
if spec.is_oauth or p.api_key: if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name return p, spec.name
# Fallback: configured local providers can route models without
# provider-specific keywords (for example plain "llama3.2" on Ollama).
for spec in PROVIDERS:
if not spec.is_local:
continue
p = getattr(self.providers, spec.name, None)
if p and p.api_base:
return p, spec.name
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection # OAuth providers are NOT valid fallbacks — they require explicit model selection
for spec in PROVIDERS: for spec in PROVIDERS:
@@ -404,7 +430,7 @@ class Config(BaseSettings):
return p.api_key if p else None return p.api_key if p else None
def get_api_base(self, model: str | None = None) -> str | None: def get_api_base(self, model: str | None = None) -> str | None:
"""Get API base URL for the given model. Applies default URLs for known gateways.""" """Get API base URL for the given model. Applies default URLs for gateway/local providers."""
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model) p, name = self._match_provider(model)
@@ -415,7 +441,7 @@ class Config(BaseSettings):
# to avoid polluting the global litellm.api_base. # to avoid polluting the global litellm.api_base.
if name: if name:
spec = find_by_name(name) spec = find_by_name(name)
if spec and spec.is_gateway and spec.default_api_base: if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
return spec.default_api_base return spec.default_api_base
return None return None

View File

@@ -51,6 +51,21 @@ class LLMResponse:
return len(self.tool_calls) > 0 return len(self.tool_calls) > 0
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
class LLMProvider(ABC): class LLMProvider(ABC):
""" """
Abstract base class for LLM providers. Abstract base class for LLM providers.
@@ -75,9 +90,12 @@ class LLMProvider(ABC):
"temporarily unavailable", "temporarily unavailable",
) )
_SENTINEL = object()
def __init__(self, api_key: str | None = None, api_base: str | None = None): def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
self.generation: GenerationSettings = GenerationSettings()
@staticmethod @staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -174,11 +192,23 @@ class LLMProvider(ABC):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
model: str | None = None, model: str | None = None,
max_tokens: int = 4096, max_tokens: object = _SENTINEL,
temperature: float = 0.7, temperature: object = _SENTINEL,
reasoning_effort: str | None = None, reasoning_effort: object = _SENTINEL,
) -> LLMResponse: ) -> LLMResponse:
"""Call chat() with retry on transient provider failures.""" """Call chat() with retry on transient provider failures.
Parameters default to ``self.generation`` when not explicitly passed,
so callers no longer need to thread temperature / max_tokens /
reasoning_effort through every layer.
"""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try: try:
response = await self.chat( response = await self.chat(

View File

@@ -360,6 +360,23 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Ollama (local, OpenAI-compatible) ===================================
ProviderSpec(
name="ollama",
keywords=("ollama", "nemotron"),
env_key="OLLAMA_API_KEY",
display_name="Ollama",
litellm_prefix="ollama_chat", # model → ollama_chat/model
skip_prefixes=("ollama/", "ollama_chat/"),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="11434",
default_api_base="http://localhost:11434",
strip_model_prefix=False,
model_overrides=(),
),
# === Auxiliary (not a primary LLM provider) ============================ # === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM. # Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.

View File

@@ -1,8 +1,12 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
import json
import re import re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any
import tiktoken
def detect_image_mime(data: bytes) -> str | None: def detect_image_mime(data: bytes) -> str | None:
@@ -68,6 +72,104 @@ def split_message(content: str, max_len: int = 2000) -> list[str]:
return chunks return chunks
def build_assistant_message(
content: str | None,
tool_calls: list[dict[str, Any]] | None = None,
reasoning_content: str | None = None,
thinking_blocks: list[dict] | None = None,
) -> dict[str, Any]:
"""Build a provider-safe assistant message with optional reasoning fields."""
msg: dict[str, Any] = {"role": "assistant", "content": content}
if tool_calls:
msg["tool_calls"] = tool_calls
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
if thinking_blocks:
msg["thinking_blocks"] = thinking_blocks
return msg
def estimate_prompt_tokens(
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> int:
"""Estimate prompt tokens with tiktoken."""
try:
enc = tiktoken.get_encoding("cl100k_base")
parts: list[str] = []
for msg in messages:
content = msg.get("content")
if isinstance(content, str):
parts.append(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
txt = part.get("text", "")
if txt:
parts.append(txt)
if tools:
parts.append(json.dumps(tools, ensure_ascii=False))
return len(enc.encode("\n".join(parts)))
except Exception:
return 0
def estimate_message_tokens(message: dict[str, Any]) -> int:
"""Estimate prompt tokens contributed by one persisted message."""
content = message.get("content")
parts: list[str] = []
if isinstance(content, str):
parts.append(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
text = part.get("text", "")
if text:
parts.append(text)
else:
parts.append(json.dumps(part, ensure_ascii=False))
elif content is not None:
parts.append(json.dumps(content, ensure_ascii=False))
for key in ("name", "tool_call_id"):
value = message.get(key)
if isinstance(value, str) and value:
parts.append(value)
if message.get("tool_calls"):
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
payload = "\n".join(parts)
if not payload:
return 1
try:
enc = tiktoken.get_encoding("cl100k_base")
return max(1, len(enc.encode(payload)))
except Exception:
return max(1, len(payload) // 4)
def estimate_prompt_tokens_chain(
provider: Any,
model: str | None,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> tuple[int, str]:
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
if callable(provider_counter):
try:
tokens, source = provider_counter(messages, tools, model)
if isinstance(tokens, (int, float)) and tokens > 0:
return int(tokens), str(source or "provider_counter")
except Exception:
pass
estimated = estimate_prompt_tokens(messages, tools)
if estimated > 0:
return int(estimated), "tiktoken"
return 0, "none"
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files.""" """Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files from importlib.resources import files as pkg_files
@@ -88,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
added.append(str(dest.relative_to(workspace))) added.append(str(dest.relative_to(workspace)))
for item in tpl.iterdir(): for item in tpl.iterdir():
if item.name.endswith(".md"): if item.name.endswith(".md") and not item.name.startswith("."):
_write(item, workspace / item.name) _write(item, workspace / item.name)
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md") _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
_write(None, workspace / "memory" / "HISTORY.md") _write(None, workspace / "memory" / "HISTORY.md")

View File

@@ -18,7 +18,7 @@ classifiers = [
dependencies = [ dependencies = [
"typer>=0.20.0,<1.0.0", "typer>=0.20.0,<1.0.0",
"litellm>=1.81.5,<2.0.0", "litellm>=1.82.1,<2.0.0",
"pydantic>=2.12.0,<3.0.0", "pydantic>=2.12.0,<3.0.0",
"pydantic-settings>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0",
"websockets>=16.0,<17.0", "websockets>=16.0,<17.0",
@@ -44,9 +44,13 @@ dependencies = [
"json-repair>=0.57.0,<1.0.0", "json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0", "chardet>=3.0.2,<6.0.0",
"openai>=2.8.0", "openai>=2.8.0",
"tiktoken>=0.12.0,<1.0.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
wecom = [
"wecom-aibot-sdk-python @ git+https://github.com/chengyongru/wecom_aibot_sdk.git@v0.1.2",
]
matrix = [ matrix = [
"matrix-nio[e2e]>=0.25.2", "matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0", "mistune>=3.0.0,<4.0.0",
@@ -68,6 +72,9 @@ nanobot = "nanobot.cli.commands:app"
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["nanobot"] packages = ["nanobot"]

View File

@@ -114,6 +114,35 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex" assert config.get_provider_name() == "openai_codex"
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
config = Config()
config.agents.defaults.provider = "ollama"
config.agents.defaults.model = "llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex") spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -267,6 +296,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
mock_agent_runtime["config"].agents.defaults.memory_window = 100
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json" config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True) config_file.parent.mkdir(parents=True)
@@ -327,6 +366,29 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
assert seen["workspace"] == override assert seen["workspace"] == override
assert config.workspace_path == override assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.memory_window = 100
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json" config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True) config_file.parent.mkdir(parents=True)

View File

@@ -0,0 +1,88 @@
import json
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 1234,
"memoryWindow": 42,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
assert config.agents.defaults.should_warn_deprecated_memory_window is True
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 2222,
"memoryWindow": 30,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 2222
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 3333,
"memoryWindow": 50,
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "contextWindowTokens" in result.stdout
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 3333
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults

View File

@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34) assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard: class TestNewCommandArchival:
"""Test that consolidation tasks are deduplicated and serialized.""" """Test /new archival behavior with the simplified consolidation flow."""
@pytest.mark.asyncio @staticmethod
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None: def _make_loop(tmp_path: Path):
"""Concurrent messages above memory_window spawn only one consolidation task."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop( loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
) )
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
return loop
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(5): for i in range(5):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session) loop.sessions.save(session)
before_count = len(session.messages) before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool: async def _failing_consolidate(_messages) -> bool:
if archive_all: return False
return False
return True
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg) response = await loop._process_message(new_msg)
assert response is not None assert response is not None
assert "failed" in response.content.lower() assert "failed" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test") assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task( async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(15): for i in range(15):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session) loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1 archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool: async def _fake_consolidate(messages) -> bool:
nonlocal archived_count nonlocal archived_count
if archive_all: archived_count = len(messages)
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
return True return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg)) response = await loop._process_message(new_msg)
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
assert archived_count == 3, ( assert archived_count == 3
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None: async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(3): for i in range(3):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
loop.sessions.save(session) loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool: async def _ok_consolidate(_messages) -> bool:
return True return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg) response = await loop._process_message(new_msg)

View File

@@ -1,9 +1,11 @@
import asyncio
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.dingtalk import DingTalkChannel import nanobot.channels.dingtalk as dingtalk_module
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
from nanobot.config.schema import DingTalkConfig from nanobot.config.schema import DingTalkConfig
@@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None:
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send" assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123" assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown" assert call["json"]["msgKey"] == "sampleMarkdown"
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeChatbotMessage:
text = None
extensions = {"content": {"recognition": "voice transcript"}}
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "audio"
@staticmethod
def from_dict(_data):
return _FakeChatbotMessage()
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "2",
"conversationId": "conv123",
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert msg.content == "voice transcript"
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"

View File

@@ -0,0 +1,190 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
)
loop.tools.get_definitions = MagicMock(return_value=[])
return loop
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
await loop.process_direct("hello", session_key="cli:test")
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
]
loop.sessions.save(session)
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
assert session.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (300, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
"""Once triggered, consolidation should continue until it drops below half threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (150, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
"""Verify preflight consolidation runs before the LLM call in process_direct."""
order: list[str] = []
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
order.append("consolidate")
return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
async def track_llm(*args, **kwargs):
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
return (1000 if call_count[0] <= 1 else 80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")

View File

@@ -7,7 +7,7 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
import json import json
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock
import pytest import pytest
@@ -15,15 +15,12 @@ from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50): def _make_messages(message_count: int = 30):
"""Create a mock session with messages.""" """Create a list of mock messages."""
session = MagicMock() return [
session.messages = [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"} {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count) for i in range(message_count)
] ]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update): def _make_tool_response(history_entry, memory_update):
@@ -74,9 +71,9 @@ class TestMemoryConsolidationTypeHandling:
) )
) )
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert store.history_file.exists() assert store.history_file.exists()
@@ -95,9 +92,9 @@ class TestMemoryConsolidationTypeHandling:
) )
) )
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert store.history_file.exists() assert store.history_file.exists()
@@ -131,9 +128,9 @@ class TestMemoryConsolidationTypeHandling:
) )
provider.chat = AsyncMock(return_value=response) provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert "User discussed testing." in store.history_file.read_text() assert "User discussed testing." in store.history_file.read_text()
@@ -147,22 +144,22 @@ class TestMemoryConsolidationTypeHandling:
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[]) return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
) )
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is False assert result is False
assert not store.history_file.exists() assert not store.history_file.exists()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None: async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count.""" """Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=10) messages: list[dict] = []
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
provider.chat.assert_not_called() provider.chat.assert_not_called()
@@ -189,9 +186,9 @@ class TestMemoryConsolidationTypeHandling:
) )
provider.chat = AsyncMock(return_value=response) provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert "User discussed testing." in store.history_file.read_text() assert "User discussed testing." in store.history_file.read_text()
@@ -215,9 +212,9 @@ class TestMemoryConsolidationTypeHandling:
) )
provider.chat = AsyncMock(return_value=response) provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is False assert result is False
@@ -239,9 +236,9 @@ class TestMemoryConsolidationTypeHandling:
) )
provider.chat = AsyncMock(return_value=response) provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat provider.chat_with_retry = provider.chat
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is False assert result is False
@@ -255,7 +252,7 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.", memory_update="# Memory\nUser likes testing.",
), ),
]) ])
session = _make_session(message_count=60) messages = _make_messages(message_count=60)
delays: list[int] = [] delays: list[int] = []
async def _fake_sleep(delay: int) -> None: async def _fake_sleep(delay: int) -> None:
@@ -263,8 +260,31 @@ class TestMemoryConsolidationTypeHandling:
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert provider.calls == 2 assert provider.calls == 2
assert delays == [1] assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
"""Consolidation no longer passes generation params — the provider owns them."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat_with_retry.assert_awaited_once()
_, kwargs = provider.chat_with_retry.await_args
assert kwargs["model"] == "test-model"
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs

View File

@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic: class TestMessageToolSuppressLogic:
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]), LLMResponse(content="Done", tool_calls=[]),
]) ])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = [] sent: list[OutboundMessage] = []
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]), LLMResponse(content="I've sent the email.", tool_calls=[]),
]) ])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = [] sent: list[OutboundMessage] = []
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
), ),
LLMResponse(content="Done", tool_calls=[]), LLMResponse(content="Done", tool_calls=[]),
]) ])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok") loop.tools.execute = AsyncMock(return_value="ok")

View File

@@ -2,7 +2,7 @@ import asyncio
import pytest import pytest
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider): class ScriptedProvider(LLMProvider):
@@ -10,9 +10,11 @@ class ScriptedProvider(LLMProvider):
super().__init__() super().__init__()
self._responses = list(responses) self._responses = list(responses)
self.calls = 0 self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse: async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1 self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0) response = self._responses.pop(0)
if isinstance(response, BaseException): if isinstance(response, BaseException):
raise response raise response
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert provider.last_kwargs["temperature"] == 0.2
assert provider.last_kwargs["max_tokens"] == 321
assert provider.last_kwargs["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
"""Explicit kwargs should override provider.generation defaults."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
temperature=0.9,
max_tokens=9999,
reasoning_effort="low",
)
assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low"

View File

@@ -165,3 +165,46 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0 assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def scripted_chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]