Merge remote-tracking branch 'origin/main' into pr-1874
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,4 +20,5 @@ __pycache__/
|
|||||||
poetry.lock
|
poetry.lock
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
botpy.log
|
botpy.log
|
||||||
|
nano.*.save
|
||||||
|
|
||||||
|
|||||||
92
README.md
92
README.md
@@ -78,6 +78,25 @@
|
|||||||
<img src="nanobot_arch.png" alt="nanobot architecture" width="800">
|
<img src="nanobot_arch.png" alt="nanobot architecture" width="800">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [News](#-news)
|
||||||
|
- [Key Features](#key-features-of-nanobot)
|
||||||
|
- [Architecture](#️-architecture)
|
||||||
|
- [Features](#-features)
|
||||||
|
- [Install](#-install)
|
||||||
|
- [Quick Start](#-quick-start)
|
||||||
|
- [Chat Apps](#-chat-apps)
|
||||||
|
- [Agent Social Network](#-agent-social-network)
|
||||||
|
- [Configuration](#️-configuration)
|
||||||
|
- [Multiple Instances](#-multiple-instances)
|
||||||
|
- [CLI Reference](#-cli-reference)
|
||||||
|
- [Docker](#-docker)
|
||||||
|
- [Linux Service](#-linux-service)
|
||||||
|
- [Project Structure](#-project-structure)
|
||||||
|
- [Contribute & Roadmap](#-contribute--roadmap)
|
||||||
|
- [Star History](#-star-history)
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
|
|
||||||
<table align="center">
|
<table align="center">
|
||||||
@@ -208,6 +227,7 @@ Connect nanobot to your favorite chat platform.
|
|||||||
| **Slack** | Bot token + App-Level token |
|
| **Slack** | Bot token + App-Level token |
|
||||||
| **Email** | IMAP/SMTP credentials |
|
| **Email** | IMAP/SMTP credentials |
|
||||||
| **QQ** | App ID + App Secret |
|
| **QQ** | App ID + App Secret |
|
||||||
|
| **Wecom** | Bot ID + Bot Secret |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Telegram</b> (Recommended)</summary>
|
<summary><b>Telegram</b> (Recommended)</summary>
|
||||||
@@ -677,6 +697,46 @@ nanobot gateway
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Wecom (企业微信)</b></summary>
|
||||||
|
|
||||||
|
> Here we use [wecom-aibot-sdk-python](https://github.com/chengyongru/wecom_aibot_sdk) (community Python version of the official [@wecom/aibot-node-sdk](https://www.npmjs.com/package/@wecom/aibot-node-sdk)).
|
||||||
|
>
|
||||||
|
> Uses **WebSocket** long connection — no public IP required.
|
||||||
|
|
||||||
|
**1. Install the optional dependency**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install nanobot-ai[wecom]
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Create a WeCom AI Bot**
|
||||||
|
|
||||||
|
Go to the WeCom admin console → Intelligent Robot → Create Robot → select **API mode** with **long connection**. Copy the Bot ID and Secret.
|
||||||
|
|
||||||
|
**3. Configure**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"wecom": {
|
||||||
|
"enabled": true,
|
||||||
|
"botId": "your_bot_id",
|
||||||
|
"secret": "your_bot_secret",
|
||||||
|
"allowFrom": ["your_id"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**4. Run**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nanobot gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## 🌐 Agent Social Network
|
## 🌐 Agent Social Network
|
||||||
|
|
||||||
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
|
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
|
||||||
@@ -718,6 +778,7 @@ Config file: `~/.nanobot/config.json`
|
|||||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||||
|
| `ollama` | LLM (local, Ollama) | — |
|
||||||
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
||||||
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
|
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
|
||||||
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
|
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
|
||||||
@@ -783,6 +844,37 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Ollama (local)</b></summary>
|
||||||
|
|
||||||
|
Run a local model with Ollama, then add to config:
|
||||||
|
|
||||||
|
**1. Start Ollama** (example):
|
||||||
|
```bash
|
||||||
|
ollama run llama3.2
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"apiBase": "http://localhost:11434"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"model": "llama3.2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> `provider: "auto"` also works when `providers.ollama.apiBase` is configured, but setting `"provider": "ollama"` is the clearest option.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
|
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
from nanobot.utils.helpers import detect_image_mime
|
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
@@ -182,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
thinking_blocks: list[dict] | None = None,
|
thinking_blocks: list[dict] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Add an assistant message to the message list."""
|
"""Add an assistant message to the message list."""
|
||||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
messages.append(build_assistant_message(
|
||||||
if tool_calls:
|
content,
|
||||||
msg["tool_calls"] = tool_calls
|
tool_calls=tool_calls,
|
||||||
if reasoning_content is not None:
|
reasoning_content=reasoning_content,
|
||||||
msg["reasoning_content"] = reasoning_content
|
thinking_blocks=thinking_blocks,
|
||||||
if thinking_blocks:
|
))
|
||||||
msg["thinking_blocks"] = thinking_blocks
|
|
||||||
messages.append(msg)
|
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import weakref
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
@@ -13,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryConsolidator
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
@@ -53,10 +52,7 @@ class AgentLoop:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 40,
|
max_iterations: int = 40,
|
||||||
temperature: float = 0.1,
|
context_window_tokens: int = 65_536,
|
||||||
max_tokens: int = 4096,
|
|
||||||
memory_window: int = 100,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
@@ -73,10 +69,7 @@ class AgentLoop:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.temperature = temperature
|
self.context_window_tokens = context_window_tokens
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.memory_window = memory_window
|
|
||||||
self.reasoning_effort = reasoning_effort
|
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@@ -91,9 +84,6 @@ class AgentLoop:
|
|||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
brave_api_key=brave_api_key,
|
brave_api_key=brave_api_key,
|
||||||
web_proxy=web_proxy,
|
web_proxy=web_proxy,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
@@ -105,11 +95,17 @@ class AgentLoop:
|
|||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
|
||||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
|
||||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
|
self.memory_consolidator = MemoryConsolidator(
|
||||||
|
workspace=workspace,
|
||||||
|
provider=provider,
|
||||||
|
model=self.model,
|
||||||
|
sessions=self.sessions,
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
build_messages=self.context.build_messages,
|
||||||
|
get_tool_definitions=self.tools.get_definitions,
|
||||||
|
)
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@@ -182,7 +178,7 @@ class AgentLoop:
|
|||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
"""Run the agent iteration loop."""
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_content = None
|
final_content = None
|
||||||
@@ -191,13 +187,12 @@ class AgentLoop:
|
|||||||
while iteration < self.max_iterations:
|
while iteration < self.max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
|
tool_defs = self.tools.get_definitions()
|
||||||
|
|
||||||
response = await self.provider.chat_with_retry(
|
response = await self.provider.chat_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tools.get_definitions(),
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=self.reasoning_effort,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
@@ -334,8 +329,9 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
history = session.get_history(max_messages=0)
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
@@ -343,6 +339,7 @@ class AgentLoop:
|
|||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@@ -355,27 +352,20 @@ class AgentLoop:
|
|||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
try:
|
try:
|
||||||
async with lock:
|
if not await self.memory_consolidator.archive_unconsolidated(session):
|
||||||
snapshot = session.messages[session.last_consolidated:]
|
return OutboundMessage(
|
||||||
if snapshot:
|
channel=msg.channel,
|
||||||
temp = Session(key=session.key)
|
chat_id=msg.chat_id,
|
||||||
temp.messages = list(snapshot)
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
if not await self._consolidate_memory(temp, archive_all=True):
|
)
|
||||||
return OutboundMessage(
|
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("/new archival failed for {}", session.key)
|
logger.exception("/new archival failed for {}", session.key)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
@@ -386,30 +376,14 @@ class AgentLoop:
|
|||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||||
|
|
||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
|
||||||
try:
|
|
||||||
async with lock:
|
|
||||||
await self._consolidate_memory(session)
|
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
_task = asyncio.current_task()
|
|
||||||
if _task is not None:
|
|
||||||
self._consolidation_tasks.discard(_task)
|
|
||||||
|
|
||||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
|
||||||
self._consolidation_tasks.add(_task)
|
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if isinstance(message_tool, MessageTool):
|
||||||
message_tool.start_turn()
|
message_tool.start_turn()
|
||||||
|
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
history = session.get_history(max_messages=0)
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
@@ -434,6 +408,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
@@ -480,13 +455,6 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
|
||||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
|
||||||
return await MemoryStore(self.workspace).consolidate(
|
|
||||||
session, self.provider, self.model,
|
|
||||||
archive_all=archive_all, memory_window=self.memory_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|||||||
@@ -2,17 +2,19 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.utils.helpers import ensure_dir
|
from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
|
||||||
_SAVE_MEMORY_TOOL = [
|
_SAVE_MEMORY_TOOL = [
|
||||||
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
|
|||||||
"properties": {
|
"properties": {
|
||||||
"history_entry": {
|
"history_entry": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
"description": "A paragraph summarizing key events/decisions/topics. "
|
||||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||||
},
|
},
|
||||||
"memory_update": {
|
"memory_update": {
|
||||||
@@ -42,6 +44,19 @@ _SAVE_MEMORY_TOOL = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_text(value: Any) -> str:
|
||||||
|
"""Normalize tool-call payload values to text for file storage."""
|
||||||
|
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
||||||
|
"""Normalize provider tool-call arguments to the expected dict shape."""
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
if isinstance(args, list):
|
||||||
|
return args[0] if args and isinstance(args[0], dict) else None
|
||||||
|
return args if isinstance(args, dict) else None
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
|
|
||||||
@@ -66,40 +81,27 @@ class MemoryStore:
|
|||||||
long_term = self.read_long_term()
|
long_term = self.read_long_term()
|
||||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_messages(messages: list[dict]) -> str:
|
||||||
|
lines = []
|
||||||
|
for message in messages:
|
||||||
|
if not message.get("content"):
|
||||||
|
continue
|
||||||
|
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
|
||||||
|
lines.append(
|
||||||
|
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
async def consolidate(
|
async def consolidate(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
messages: list[dict],
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
*,
|
|
||||||
archive_all: bool = False,
|
|
||||||
memory_window: int = 50,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
"""Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
|
||||||
|
if not messages:
|
||||||
Returns True on success (including no-op), False on failure.
|
return True
|
||||||
"""
|
|
||||||
if archive_all:
|
|
||||||
old_messages = session.messages
|
|
||||||
keep_count = 0
|
|
||||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
|
||||||
else:
|
|
||||||
keep_count = memory_window // 2
|
|
||||||
if len(session.messages) <= keep_count:
|
|
||||||
return True
|
|
||||||
if len(session.messages) - session.last_consolidated <= 0:
|
|
||||||
return True
|
|
||||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
|
||||||
if not old_messages:
|
|
||||||
return True
|
|
||||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for m in old_messages:
|
|
||||||
if not m.get("content"):
|
|
||||||
continue
|
|
||||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
|
||||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
|
||||||
|
|
||||||
current_memory = self.read_long_term()
|
current_memory = self.read_long_term()
|
||||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||||
@@ -108,7 +110,7 @@ class MemoryStore:
|
|||||||
{current_memory or "(empty)"}
|
{current_memory or "(empty)"}
|
||||||
|
|
||||||
## Conversation to Process
|
## Conversation to Process
|
||||||
{chr(10).join(lines)}"""
|
{self._format_messages(messages)}"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await provider.chat_with_retry(
|
response = await provider.chat_with_retry(
|
||||||
@@ -124,34 +126,158 @@ class MemoryStore:
|
|||||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
args = response.tool_calls[0].arguments
|
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||||
# Some providers return arguments as a JSON string instead of dict
|
if args is None:
|
||||||
if isinstance(args, str):
|
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||||
args = json.loads(args)
|
|
||||||
# Some providers return arguments as a list (handle edge case)
|
|
||||||
if isinstance(args, list):
|
|
||||||
if args and isinstance(args[0], dict):
|
|
||||||
args = args[0]
|
|
||||||
else:
|
|
||||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
|
||||||
return False
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if entry := args.get("history_entry"):
|
||||||
if not isinstance(entry, str):
|
self.append_history(_ensure_text(entry))
|
||||||
entry = json.dumps(entry, ensure_ascii=False)
|
|
||||||
self.append_history(entry)
|
|
||||||
if update := args.get("memory_update"):
|
if update := args.get("memory_update"):
|
||||||
if not isinstance(update, str):
|
update = _ensure_text(update)
|
||||||
update = json.dumps(update, ensure_ascii=False)
|
|
||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
self.write_long_term(update)
|
self.write_long_term(update)
|
||||||
|
|
||||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Memory consolidation failed")
|
logger.exception("Memory consolidation failed")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConsolidator:
|
||||||
|
"""Owns consolidation policy, locking, and session offset updates."""
|
||||||
|
|
||||||
|
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace: Path,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
sessions: SessionManager,
|
||||||
|
context_window_tokens: int,
|
||||||
|
build_messages: Callable[..., list[dict[str, Any]]],
|
||||||
|
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||||
|
):
|
||||||
|
self.store = MemoryStore(workspace)
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.sessions = sessions
|
||||||
|
self.context_window_tokens = context_window_tokens
|
||||||
|
self._build_messages = build_messages
|
||||||
|
self._get_tool_definitions = get_tool_definitions
|
||||||
|
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||||
|
|
||||||
|
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
|
"""Return the shared consolidation lock for one session."""
|
||||||
|
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
|
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||||
|
"""Archive a selected message chunk into persistent memory."""
|
||||||
|
return await self.store.consolidate(messages, self.provider, self.model)
|
||||||
|
|
||||||
|
def pick_consolidation_boundary(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
tokens_to_remove: int,
|
||||||
|
) -> tuple[int, int] | None:
|
||||||
|
"""Pick a user-turn boundary that removes enough old prompt tokens."""
|
||||||
|
start = session.last_consolidated
|
||||||
|
if start >= len(session.messages) or tokens_to_remove <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
removed_tokens = 0
|
||||||
|
last_boundary: tuple[int, int] | None = None
|
||||||
|
for idx in range(start, len(session.messages)):
|
||||||
|
message = session.messages[idx]
|
||||||
|
if idx > start and message.get("role") == "user":
|
||||||
|
last_boundary = (idx, removed_tokens)
|
||||||
|
if removed_tokens >= tokens_to_remove:
|
||||||
|
return last_boundary
|
||||||
|
removed_tokens += estimate_message_tokens(message)
|
||||||
|
|
||||||
|
return last_boundary
|
||||||
|
|
||||||
|
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||||
|
"""Estimate current prompt size for the normal session history view."""
|
||||||
|
history = session.get_history(max_messages=0)
|
||||||
|
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||||
|
probe_messages = self._build_messages(
|
||||||
|
history=history,
|
||||||
|
current_message="[token-probe]",
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
)
|
||||||
|
return estimate_prompt_tokens_chain(
|
||||||
|
self.provider,
|
||||||
|
self.model,
|
||||||
|
probe_messages,
|
||||||
|
self._get_tool_definitions(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def archive_unconsolidated(self, session: Session) -> bool:
|
||||||
|
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||||
|
lock = self.get_lock(session.key)
|
||||||
|
async with lock:
|
||||||
|
snapshot = session.messages[session.last_consolidated:]
|
||||||
|
if not snapshot:
|
||||||
|
return True
|
||||||
|
return await self.consolidate_messages(snapshot)
|
||||||
|
|
||||||
|
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||||
|
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||||
|
if not session.messages or self.context_window_tokens <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
lock = self.get_lock(session.key)
|
||||||
|
async with lock:
|
||||||
|
target = self.context_window_tokens // 2
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
if estimated <= 0:
|
||||||
|
return
|
||||||
|
if estimated < self.context_window_tokens:
|
||||||
|
logger.debug(
|
||||||
|
"Token consolidation idle {}: {}/{} via {}",
|
||||||
|
session.key,
|
||||||
|
estimated,
|
||||||
|
self.context_window_tokens,
|
||||||
|
source,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
|
||||||
|
if estimated <= target:
|
||||||
|
return
|
||||||
|
|
||||||
|
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
|
||||||
|
if boundary is None:
|
||||||
|
logger.debug(
|
||||||
|
"Token consolidation: no safe boundary for {} (round {})",
|
||||||
|
session.key,
|
||||||
|
round_num,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
end_idx = boundary[0]
|
||||||
|
chunk = session.messages[session.last_consolidated:end_idx]
|
||||||
|
if not chunk:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
|
||||||
|
round_num,
|
||||||
|
session.key,
|
||||||
|
estimated,
|
||||||
|
self.context_window_tokens,
|
||||||
|
source,
|
||||||
|
len(chunk),
|
||||||
|
)
|
||||||
|
if not await self.consolidate_messages(chunk):
|
||||||
|
return
|
||||||
|
session.last_consolidated = end_idx
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
if estimated <= 0:
|
||||||
|
return
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.utils.helpers import build_assistant_message
|
||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
@@ -27,9 +28,6 @@ class SubagentManager:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
@@ -40,9 +38,6 @@ class SubagentManager:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.temperature = temperature
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.reasoning_effort = reasoning_effort
|
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@@ -127,22 +122,19 @@ class SubagentManager:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=self.reasoning_effort,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
# Add assistant message with tool calls
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
tc.to_openai_tool_call()
|
tc.to_openai_tool_call()
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
]
|
]
|
||||||
messages.append({
|
messages.append(build_assistant_message(
|
||||||
"role": "assistant",
|
response.content or "",
|
||||||
"content": response.content or "",
|
tool_calls=tool_call_dicts,
|
||||||
"tool_calls": tool_call_dicts,
|
reasoning_content=response.reasoning_content,
|
||||||
})
|
thinking_blocks=response.thinking_blocks,
|
||||||
|
))
|
||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
"""Base channel interface for chat platforms."""
|
"""Base channel interface for chat platforms."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -18,6 +21,8 @@ class BaseChannel(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "base"
|
name: str = "base"
|
||||||
|
display_name: str = "Base"
|
||||||
|
transcription_api_key: str = ""
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
"""
|
"""
|
||||||
@@ -31,6 +36,19 @@ class BaseChannel(ABC):
|
|||||||
self.bus = bus
|
self.bus = bus
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
async def transcribe_audio(self, file_path: str | Path) -> str:
|
||||||
|
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
|
||||||
|
if not self.transcription_api_key:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||||
|
|
||||||
|
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
|
||||||
|
return await provider.transcribe(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("{}: audio transcription failed: {}", self.name, e)
|
||||||
|
return ""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
content = ""
|
content = ""
|
||||||
if chatbot_msg.text:
|
if chatbot_msg.text:
|
||||||
content = chatbot_msg.text.content.strip()
|
content = chatbot_msg.text.content.strip()
|
||||||
|
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
|
||||||
|
content = chatbot_msg.extensions["content"]["recognition"].strip()
|
||||||
if not content:
|
if not content:
|
||||||
content = message.data.get("text", {}).get("content", "").strip()
|
content = message.data.get("text", {}).get("content", "").strip()
|
||||||
|
|
||||||
@@ -112,6 +114,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "dingtalk"
|
name = "dingtalk"
|
||||||
|
display_name = "DingTalk"
|
||||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""Discord channel using Gateway websocket."""
|
"""Discord channel using Gateway websocket."""
|
||||||
|
|
||||||
name = "discord"
|
name = "discord"
|
||||||
|
display_name = "Discord"
|
||||||
|
|
||||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class EmailChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "email"
|
name = "email"
|
||||||
|
display_name = "Email"
|
||||||
_IMAP_MONTHS = (
|
_IMAP_MONTHS = (
|
||||||
"Jan",
|
"Jan",
|
||||||
"Feb",
|
"Feb",
|
||||||
|
|||||||
@@ -244,11 +244,11 @@ class FeishuChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "feishu"
|
name = "feishu"
|
||||||
|
display_name = "Feishu"
|
||||||
|
|
||||||
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
|
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.config: FeishuConfig = config
|
self.config: FeishuConfig = config
|
||||||
self.groq_api_key = groq_api_key
|
|
||||||
self._client: Any = None
|
self._client: Any = None
|
||||||
self._ws_client: Any = None
|
self._ws_client: Any = None
|
||||||
self._ws_thread: threading.Thread | None = None
|
self._ws_thread: threading.Thread | None = None
|
||||||
@@ -928,16 +928,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
if file_path:
|
if file_path:
|
||||||
media_paths.append(file_path)
|
media_paths.append(file_path)
|
||||||
|
|
||||||
# Transcribe audio using Groq Whisper
|
if msg_type == "audio" and file_path:
|
||||||
if msg_type == "audio" and file_path and self.groq_api_key:
|
transcription = await self.transcribe_audio(file_path)
|
||||||
try:
|
if transcription:
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
content_text = f"[transcription: {transcription}]"
|
||||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
|
||||||
transcription = await transcriber.transcribe(file_path)
|
|
||||||
if transcription:
|
|
||||||
content_text = f"[transcription: {transcription}]"
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to transcribe audio: {}", e)
|
|
||||||
|
|
||||||
content_parts.append(content_text)
|
content_parts.append(content_text)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
@@ -32,123 +31,23 @@ class ChannelManager:
|
|||||||
self._init_channels()
|
self._init_channels()
|
||||||
|
|
||||||
def _init_channels(self) -> None:
|
def _init_channels(self) -> None:
|
||||||
"""Initialize channels based on config."""
|
"""Initialize channels discovered via pkgutil scan."""
|
||||||
|
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||||
|
|
||||||
# Telegram channel
|
groq_key = self.config.providers.groq.api_key
|
||||||
if self.config.channels.telegram.enabled:
|
|
||||||
|
for modname in discover_channel_names():
|
||||||
|
section = getattr(self.config.channels, modname, None)
|
||||||
|
if not section or not getattr(section, "enabled", False):
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
from nanobot.channels.telegram import TelegramChannel
|
cls = load_channel_class(modname)
|
||||||
self.channels["telegram"] = TelegramChannel(
|
channel = cls(section, self.bus)
|
||||||
self.config.channels.telegram,
|
channel.transcription_api_key = groq_key
|
||||||
self.bus,
|
self.channels[modname] = channel
|
||||||
groq_api_key=self.config.providers.groq.api_key,
|
logger.info("{} channel enabled", cls.display_name)
|
||||||
)
|
|
||||||
logger.info("Telegram channel enabled")
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Telegram channel not available: {}", e)
|
logger.warning("{} channel not available: {}", modname, e)
|
||||||
|
|
||||||
# WhatsApp channel
|
|
||||||
if self.config.channels.whatsapp.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
|
||||||
self.channels["whatsapp"] = WhatsAppChannel(
|
|
||||||
self.config.channels.whatsapp, self.bus
|
|
||||||
)
|
|
||||||
logger.info("WhatsApp channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("WhatsApp channel not available: {}", e)
|
|
||||||
|
|
||||||
# Discord channel
|
|
||||||
if self.config.channels.discord.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.discord import DiscordChannel
|
|
||||||
self.channels["discord"] = DiscordChannel(
|
|
||||||
self.config.channels.discord, self.bus
|
|
||||||
)
|
|
||||||
logger.info("Discord channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Discord channel not available: {}", e)
|
|
||||||
|
|
||||||
# Feishu channel
|
|
||||||
if self.config.channels.feishu.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.feishu import FeishuChannel
|
|
||||||
self.channels["feishu"] = FeishuChannel(
|
|
||||||
self.config.channels.feishu, self.bus,
|
|
||||||
groq_api_key=self.config.providers.groq.api_key,
|
|
||||||
)
|
|
||||||
logger.info("Feishu channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Feishu channel not available: {}", e)
|
|
||||||
|
|
||||||
# Mochat channel
|
|
||||||
if self.config.channels.mochat.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.mochat import MochatChannel
|
|
||||||
|
|
||||||
self.channels["mochat"] = MochatChannel(
|
|
||||||
self.config.channels.mochat, self.bus
|
|
||||||
)
|
|
||||||
logger.info("Mochat channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Mochat channel not available: {}", e)
|
|
||||||
|
|
||||||
# DingTalk channel
|
|
||||||
if self.config.channels.dingtalk.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.dingtalk import DingTalkChannel
|
|
||||||
self.channels["dingtalk"] = DingTalkChannel(
|
|
||||||
self.config.channels.dingtalk, self.bus
|
|
||||||
)
|
|
||||||
logger.info("DingTalk channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("DingTalk channel not available: {}", e)
|
|
||||||
|
|
||||||
# Email channel
|
|
||||||
if self.config.channels.email.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.email import EmailChannel
|
|
||||||
self.channels["email"] = EmailChannel(
|
|
||||||
self.config.channels.email, self.bus
|
|
||||||
)
|
|
||||||
logger.info("Email channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Email channel not available: {}", e)
|
|
||||||
|
|
||||||
# Slack channel
|
|
||||||
if self.config.channels.slack.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.slack import SlackChannel
|
|
||||||
self.channels["slack"] = SlackChannel(
|
|
||||||
self.config.channels.slack, self.bus
|
|
||||||
)
|
|
||||||
logger.info("Slack channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Slack channel not available: {}", e)
|
|
||||||
|
|
||||||
# QQ channel
|
|
||||||
if self.config.channels.qq.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.qq import QQChannel
|
|
||||||
self.channels["qq"] = QQChannel(
|
|
||||||
self.config.channels.qq,
|
|
||||||
self.bus,
|
|
||||||
)
|
|
||||||
logger.info("QQ channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("QQ channel not available: {}", e)
|
|
||||||
|
|
||||||
# Matrix channel
|
|
||||||
if self.config.channels.matrix.enabled:
|
|
||||||
try:
|
|
||||||
from nanobot.channels.matrix import MatrixChannel
|
|
||||||
self.channels["matrix"] = MatrixChannel(
|
|
||||||
self.config.channels.matrix,
|
|
||||||
self.bus,
|
|
||||||
)
|
|
||||||
logger.info("Matrix channel enabled")
|
|
||||||
except ImportError as e:
|
|
||||||
logger.warning("Matrix channel not available: {}", e)
|
|
||||||
|
|
||||||
self._validate_allow_from()
|
self._validate_allow_from()
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ except ImportError as e:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||||
from nanobot.utils.helpers import safe_filename
|
from nanobot.utils.helpers import safe_filename
|
||||||
@@ -146,15 +147,15 @@ class MatrixChannel(BaseChannel):
|
|||||||
"""Matrix (Element) channel using long-polling sync."""
|
"""Matrix (Element) channel using long-polling sync."""
|
||||||
|
|
||||||
name = "matrix"
|
name = "matrix"
|
||||||
|
display_name = "Matrix"
|
||||||
|
|
||||||
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
workspace: Path | None = None):
|
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.client: AsyncClient | None = None
|
self.client: AsyncClient | None = None
|
||||||
self._sync_task: asyncio.Task | None = None
|
self._sync_task: asyncio.Task | None = None
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||||
self._restrict_to_workspace = restrict_to_workspace
|
self._restrict_to_workspace = False
|
||||||
self._workspace = workspace.expanduser().resolve() if workspace else None
|
self._workspace: Path | None = None
|
||||||
self._server_upload_limit_bytes: int | None = None
|
self._server_upload_limit_bytes: int | None = None
|
||||||
self._server_upload_limit_checked = False
|
self._server_upload_limit_checked = False
|
||||||
|
|
||||||
@@ -677,7 +678,14 @@ class MatrixChannel(BaseChannel):
|
|||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||||
parts.append(body.strip())
|
parts.append(body.strip())
|
||||||
if marker:
|
|
||||||
|
if attachment and attachment.get("type") == "audio":
|
||||||
|
transcription = await self.transcribe_audio(attachment["path"])
|
||||||
|
if transcription:
|
||||||
|
parts.append(f"[transcription: {transcription}]")
|
||||||
|
else:
|
||||||
|
parts.append(marker)
|
||||||
|
elif marker:
|
||||||
parts.append(marker)
|
parts.append(marker)
|
||||||
|
|
||||||
await self._start_typing_keepalive(room.room_id)
|
await self._start_typing_keepalive(room.room_id)
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ class MochatChannel(BaseChannel):
|
|||||||
"""Mochat channel using socket.io with fallback polling workers."""
|
"""Mochat channel using socket.io with fallback polling workers."""
|
||||||
|
|
||||||
name = "mochat"
|
name = "mochat"
|
||||||
|
display_name = "Mochat"
|
||||||
|
|
||||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class QQChannel(BaseChannel):
|
|||||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||||
|
|
||||||
name = "qq"
|
name = "qq"
|
||||||
|
display_name = "QQ"
|
||||||
|
|
||||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
|
|||||||
35
nanobot/channels/registry.py
Normal file
35
nanobot/channels/registry.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Auto-discovery for channel modules — no hardcoded registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import pkgutil
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
|
||||||
|
_INTERNAL = frozenset({"base", "manager", "registry"})
|
||||||
|
|
||||||
|
|
||||||
|
def discover_channel_names() -> list[str]:
|
||||||
|
"""Return all channel module names by scanning the package (zero imports)."""
|
||||||
|
import nanobot.channels as pkg
|
||||||
|
|
||||||
|
return [
|
||||||
|
name
|
||||||
|
for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
|
||||||
|
if name not in _INTERNAL and not ispkg
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def load_channel_class(module_name: str) -> type[BaseChannel]:
|
||||||
|
"""Import *module_name* and return the first BaseChannel subclass found."""
|
||||||
|
from nanobot.channels.base import BaseChannel as _Base
|
||||||
|
|
||||||
|
mod = importlib.import_module(f"nanobot.channels.{module_name}")
|
||||||
|
for attr in dir(mod):
|
||||||
|
obj = getattr(mod, attr)
|
||||||
|
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
||||||
|
return obj
|
||||||
|
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
||||||
@@ -21,6 +21,7 @@ class SlackChannel(BaseChannel):
|
|||||||
"""Slack channel using Socket Mode."""
|
"""Slack channel using Socket Mode."""
|
||||||
|
|
||||||
name = "slack"
|
name = "slack"
|
||||||
|
display_name = "Slack"
|
||||||
|
|
||||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "telegram"
|
name = "telegram"
|
||||||
|
display_name = "Telegram"
|
||||||
|
|
||||||
# Commands registered with Telegram's command menu
|
# Commands registered with Telegram's command menu
|
||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
@@ -164,15 +165,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config: TelegramConfig, bus: MessageBus):
|
||||||
self,
|
|
||||||
config: TelegramConfig,
|
|
||||||
bus: MessageBus,
|
|
||||||
groq_api_key: str = "",
|
|
||||||
):
|
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.config: TelegramConfig = config
|
self.config: TelegramConfig = config
|
||||||
self.groq_api_key = groq_api_key
|
|
||||||
self._app: Application | None = None
|
self._app: Application | None = None
|
||||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
@@ -615,11 +610,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
media_paths.append(str(file_path))
|
media_paths.append(str(file_path))
|
||||||
|
|
||||||
# Handle voice transcription
|
if media_type in ("voice", "audio"):
|
||||||
if media_type == "voice" or media_type == "audio":
|
transcription = await self.transcribe_audio(file_path)
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
|
||||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
|
||||||
transcription = await transcriber.transcribe(file_path)
|
|
||||||
if transcription:
|
if transcription:
|
||||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||||
content_parts.append(f"[transcription: {transcription}]")
|
content_parts.append(f"[transcription: {transcription}]")
|
||||||
|
|||||||
353
nanobot/channels/wecom.py
Normal file
353
nanobot/channels/wecom.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.config.paths import get_media_dir
|
||||||
|
from nanobot.config.schema import WecomConfig
|
||||||
|
|
||||||
|
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||||
|
|
||||||
|
# Message type display mapping
|
||||||
|
MSG_TYPE_MAP = {
|
||||||
|
"image": "[image]",
|
||||||
|
"voice": "[voice]",
|
||||||
|
"file": "[file]",
|
||||||
|
"mixed": "[mixed content]",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WecomChannel(BaseChannel):
|
||||||
|
"""
|
||||||
|
WeCom (Enterprise WeChat) channel using WebSocket long connection.
|
||||||
|
|
||||||
|
Uses WebSocket to receive events - no public IP or webhook required.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- Bot ID and Secret from WeCom AI Bot platform
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "wecom"
|
||||||
|
display_name = "WeCom"
|
||||||
|
|
||||||
|
def __init__(self, config: WecomConfig, bus: MessageBus):
|
||||||
|
super().__init__(config, bus)
|
||||||
|
self.config: WecomConfig = config
|
||||||
|
self._client: Any = None
|
||||||
|
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||||
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
self._generate_req_id = None
|
||||||
|
# Store frame headers for each chat to enable replies
|
||||||
|
self._chat_frames: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the WeCom bot with WebSocket long connection."""
|
||||||
|
if not WECOM_AVAILABLE:
|
||||||
|
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.config.bot_id or not self.config.secret:
|
||||||
|
logger.error("WeCom bot_id and secret not configured")
|
||||||
|
return
|
||||||
|
|
||||||
|
from wecom_aibot_sdk import WSClient, generate_req_id
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
|
self._generate_req_id = generate_req_id
|
||||||
|
|
||||||
|
# Create WebSocket client
|
||||||
|
self._client = WSClient({
|
||||||
|
"bot_id": self.config.bot_id,
|
||||||
|
"secret": self.config.secret,
|
||||||
|
"reconnect_interval": 1000,
|
||||||
|
"max_reconnect_attempts": -1, # Infinite reconnect
|
||||||
|
"heartbeat_interval": 30000,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Register event handlers
|
||||||
|
self._client.on("connected", self._on_connected)
|
||||||
|
self._client.on("authenticated", self._on_authenticated)
|
||||||
|
self._client.on("disconnected", self._on_disconnected)
|
||||||
|
self._client.on("error", self._on_error)
|
||||||
|
self._client.on("message.text", self._on_text_message)
|
||||||
|
self._client.on("message.image", self._on_image_message)
|
||||||
|
self._client.on("message.voice", self._on_voice_message)
|
||||||
|
self._client.on("message.file", self._on_file_message)
|
||||||
|
self._client.on("message.mixed", self._on_mixed_message)
|
||||||
|
self._client.on("event.enter_chat", self._on_enter_chat)
|
||||||
|
|
||||||
|
logger.info("WeCom bot starting with WebSocket long connection")
|
||||||
|
logger.info("No public IP required - using WebSocket to receive events")
|
||||||
|
|
||||||
|
# Connect
|
||||||
|
await self._client.connect_async()
|
||||||
|
|
||||||
|
# Keep running until stopped
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the WeCom bot."""
|
||||||
|
self._running = False
|
||||||
|
if self._client:
|
||||||
|
await self._client.disconnect()
|
||||||
|
logger.info("WeCom bot stopped")
|
||||||
|
|
||||||
|
async def _on_connected(self, frame: Any) -> None:
|
||||||
|
"""Handle WebSocket connected event."""
|
||||||
|
logger.info("WeCom WebSocket connected")
|
||||||
|
|
||||||
|
async def _on_authenticated(self, frame: Any) -> None:
|
||||||
|
"""Handle authentication success event."""
|
||||||
|
logger.info("WeCom authenticated successfully")
|
||||||
|
|
||||||
|
async def _on_disconnected(self, frame: Any) -> None:
|
||||||
|
"""Handle WebSocket disconnected event."""
|
||||||
|
reason = frame.body if hasattr(frame, 'body') else str(frame)
|
||||||
|
logger.warning("WeCom WebSocket disconnected: {}", reason)
|
||||||
|
|
||||||
|
async def _on_error(self, frame: Any) -> None:
|
||||||
|
"""Handle error event."""
|
||||||
|
logger.error("WeCom error: {}", frame)
|
||||||
|
|
||||||
|
async def _on_text_message(self, frame: Any) -> None:
|
||||||
|
"""Handle text message."""
|
||||||
|
await self._process_message(frame, "text")
|
||||||
|
|
||||||
|
async def _on_image_message(self, frame: Any) -> None:
|
||||||
|
"""Handle image message."""
|
||||||
|
await self._process_message(frame, "image")
|
||||||
|
|
||||||
|
async def _on_voice_message(self, frame: Any) -> None:
|
||||||
|
"""Handle voice message."""
|
||||||
|
await self._process_message(frame, "voice")
|
||||||
|
|
||||||
|
async def _on_file_message(self, frame: Any) -> None:
|
||||||
|
"""Handle file message."""
|
||||||
|
await self._process_message(frame, "file")
|
||||||
|
|
||||||
|
async def _on_mixed_message(self, frame: Any) -> None:
|
||||||
|
"""Handle mixed content message."""
|
||||||
|
await self._process_message(frame, "mixed")
|
||||||
|
|
||||||
|
async def _on_enter_chat(self, frame: Any) -> None:
|
||||||
|
"""Handle enter_chat event (user opens chat with bot)."""
|
||||||
|
try:
|
||||||
|
# Extract body from WsFrame dataclass or dict
|
||||||
|
if hasattr(frame, 'body'):
|
||||||
|
body = frame.body or {}
|
||||||
|
elif isinstance(frame, dict):
|
||||||
|
body = frame.get("body", frame)
|
||||||
|
else:
|
||||||
|
body = {}
|
||||||
|
|
||||||
|
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
|
||||||
|
|
||||||
|
if chat_id and self.config.welcome_message:
|
||||||
|
await self._client.reply_welcome(frame, {
|
||||||
|
"msgtype": "text",
|
||||||
|
"text": {"content": self.config.welcome_message},
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error handling enter_chat: {}", e)
|
||||||
|
|
||||||
|
async def _process_message(self, frame: Any, msg_type: str) -> None:
|
||||||
|
"""Process incoming message and forward to bus."""
|
||||||
|
try:
|
||||||
|
# Extract body from WsFrame dataclass or dict
|
||||||
|
if hasattr(frame, 'body'):
|
||||||
|
body = frame.body or {}
|
||||||
|
elif isinstance(frame, dict):
|
||||||
|
body = frame.get("body", frame)
|
||||||
|
else:
|
||||||
|
body = {}
|
||||||
|
|
||||||
|
# Ensure body is a dict
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
logger.warning("Invalid body type: {}", type(body))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract message info
|
||||||
|
msg_id = body.get("msgid", "")
|
||||||
|
if not msg_id:
|
||||||
|
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
|
||||||
|
|
||||||
|
# Deduplication check
|
||||||
|
if msg_id in self._processed_message_ids:
|
||||||
|
return
|
||||||
|
self._processed_message_ids[msg_id] = None
|
||||||
|
|
||||||
|
# Trim cache
|
||||||
|
while len(self._processed_message_ids) > 1000:
|
||||||
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
|
# Extract sender info from "from" field (SDK format)
|
||||||
|
from_info = body.get("from", {})
|
||||||
|
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
|
||||||
|
|
||||||
|
# For single chat, chatid is the sender's userid
|
||||||
|
# For group chat, chatid is provided in body
|
||||||
|
chat_type = body.get("chattype", "single")
|
||||||
|
chat_id = body.get("chatid", sender_id)
|
||||||
|
|
||||||
|
content_parts = []
|
||||||
|
|
||||||
|
if msg_type == "text":
|
||||||
|
text = body.get("text", {}).get("content", "")
|
||||||
|
if text:
|
||||||
|
content_parts.append(text)
|
||||||
|
|
||||||
|
elif msg_type == "image":
|
||||||
|
image_info = body.get("image", {})
|
||||||
|
file_url = image_info.get("url", "")
|
||||||
|
aes_key = image_info.get("aeskey", "")
|
||||||
|
|
||||||
|
if file_url and aes_key:
|
||||||
|
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||||
|
if file_path:
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
|
||||||
|
else:
|
||||||
|
content_parts.append("[image: download failed]")
|
||||||
|
else:
|
||||||
|
content_parts.append("[image: download failed]")
|
||||||
|
|
||||||
|
elif msg_type == "voice":
|
||||||
|
voice_info = body.get("voice", {})
|
||||||
|
# Voice message already contains transcribed content from WeCom
|
||||||
|
voice_content = voice_info.get("content", "")
|
||||||
|
if voice_content:
|
||||||
|
content_parts.append(f"[voice] {voice_content}")
|
||||||
|
else:
|
||||||
|
content_parts.append("[voice]")
|
||||||
|
|
||||||
|
elif msg_type == "file":
|
||||||
|
file_info = body.get("file", {})
|
||||||
|
file_url = file_info.get("url", "")
|
||||||
|
aes_key = file_info.get("aeskey", "")
|
||||||
|
file_name = file_info.get("name", "unknown")
|
||||||
|
|
||||||
|
if file_url and aes_key:
|
||||||
|
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||||
|
if file_path:
|
||||||
|
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
||||||
|
else:
|
||||||
|
content_parts.append(f"[file: {file_name}: download failed]")
|
||||||
|
else:
|
||||||
|
content_parts.append(f"[file: {file_name}: download failed]")
|
||||||
|
|
||||||
|
elif msg_type == "mixed":
|
||||||
|
# Mixed content contains multiple message items
|
||||||
|
msg_items = body.get("mixed", {}).get("item", [])
|
||||||
|
for item in msg_items:
|
||||||
|
item_type = item.get("type", "")
|
||||||
|
if item_type == "text":
|
||||||
|
text = item.get("text", {}).get("content", "")
|
||||||
|
if text:
|
||||||
|
content_parts.append(text)
|
||||||
|
else:
|
||||||
|
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
|
||||||
|
|
||||||
|
else:
|
||||||
|
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||||
|
|
||||||
|
content = "\n".join(content_parts) if content_parts else ""
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Store frame for this chat to enable replies
|
||||||
|
self._chat_frames[chat_id] = frame
|
||||||
|
|
||||||
|
# Forward to message bus
|
||||||
|
# Note: media paths are included in content for broader model compatibility
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=sender_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=content,
|
||||||
|
media=None,
|
||||||
|
metadata={
|
||||||
|
"message_id": msg_id,
|
||||||
|
"msg_type": msg_type,
|
||||||
|
"chat_type": chat_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing WeCom message: {}", e)
|
||||||
|
|
||||||
|
async def _download_and_save_media(
|
||||||
|
self,
|
||||||
|
file_url: str,
|
||||||
|
aes_key: str,
|
||||||
|
media_type: str,
|
||||||
|
filename: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Download and decrypt media from WeCom.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
file_path or None if download failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data, fname = await self._client.download_file(file_url, aes_key)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
logger.warning("Failed to download media from WeCom")
|
||||||
|
return None
|
||||||
|
|
||||||
|
media_dir = get_media_dir("wecom")
|
||||||
|
if not filename:
|
||||||
|
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
|
||||||
|
file_path = media_dir / filename
|
||||||
|
file_path.write_bytes(data)
|
||||||
|
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||||
|
return str(file_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading media: {}", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
|
"""Send a message through WeCom."""
|
||||||
|
if not self._client:
|
||||||
|
logger.warning("WeCom client not initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = msg.content.strip()
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the stored frame for this chat
|
||||||
|
frame = self._chat_frames.get(msg.chat_id)
|
||||||
|
if not frame:
|
||||||
|
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use streaming reply for better UX
|
||||||
|
stream_id = self._generate_req_id("stream")
|
||||||
|
|
||||||
|
# Send as streaming message with finish=True
|
||||||
|
await self._client.reply_stream(
|
||||||
|
frame,
|
||||||
|
stream_id,
|
||||||
|
content,
|
||||||
|
finish=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("WeCom message sent to {}", msg.chat_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error sending WeCom message: {}", e)
|
||||||
@@ -22,6 +22,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "whatsapp"
|
name = "whatsapp"
|
||||||
|
display_name = "WhatsApp"
|
||||||
|
|
||||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
|
|||||||
@@ -191,6 +191,8 @@ def onboard():
|
|||||||
save_config(Config())
|
save_config(Config())
|
||||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||||
|
|
||||||
|
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
|
||||||
|
|
||||||
# Create workspace
|
# Create workspace
|
||||||
workspace = get_workspace_path()
|
workspace = get_workspace_path()
|
||||||
|
|
||||||
@@ -213,6 +215,7 @@ def onboard():
|
|||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
|
from nanobot.providers.base import GenerationSettings
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
@@ -222,46 +225,50 @@ def _make_provider(config: Config):
|
|||||||
|
|
||||||
# OpenAI Codex (OAuth)
|
# OpenAI Codex (OAuth)
|
||||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||||
return OpenAICodexProvider(default_model=model)
|
provider = OpenAICodexProvider(default_model=model)
|
||||||
|
|
||||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
elif provider_name == "custom":
|
||||||
if provider_name == "custom":
|
from nanobot.providers.custom_provider import CustomProvider
|
||||||
return CustomProvider(
|
provider = CustomProvider(
|
||||||
api_key=p.api_key if p else "no-key",
|
api_key=p.api_key if p else "no-key",
|
||||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||||
if provider_name == "azure_openai":
|
elif provider_name == "azure_openai":
|
||||||
if not p or not p.api_key or not p.api_base:
|
if not p or not p.api_key or not p.api_base:
|
||||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||||
console.print("Use the model field to specify the deployment name.")
|
console.print("Use the model field to specify the deployment name.")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
return AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key,
|
api_key=p.api_key,
|
||||||
api_base=p.api_base,
|
api_base=p.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
|
from nanobot.providers.registry import find_by_name
|
||||||
|
spec = find_by_name(provider_name)
|
||||||
|
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
||||||
|
console.print("[red]Error: No API key configured.[/red]")
|
||||||
|
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
provider = LiteLLMProvider(
|
||||||
|
api_key=p.api_key if p else None,
|
||||||
|
api_base=config.get_api_base(model),
|
||||||
|
default_model=model,
|
||||||
|
extra_headers=p.extra_headers if p else None,
|
||||||
|
provider_name=provider_name,
|
||||||
|
)
|
||||||
|
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
defaults = config.agents.defaults
|
||||||
from nanobot.providers.registry import find_by_name
|
provider.generation = GenerationSettings(
|
||||||
spec = find_by_name(provider_name)
|
temperature=defaults.temperature,
|
||||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
|
max_tokens=defaults.max_tokens,
|
||||||
console.print("[red]Error: No API key configured.[/red]")
|
reasoning_effort=defaults.reasoning_effort,
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
return LiteLLMProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
provider_name=provider_name,
|
|
||||||
)
|
)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||||
@@ -283,6 +290,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
|
|||||||
return loaded
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _print_deprecated_memory_window_notice(config: Config) -> None:
|
||||||
|
"""Warn when running with old memoryWindow-only config."""
|
||||||
|
if config.agents.defaults.should_warn_deprecated_memory_window:
|
||||||
|
console.print(
|
||||||
|
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
|
||||||
|
"`contextWindowTokens`. `memoryWindow` is ignored; run "
|
||||||
|
"[cyan]nanobot onboard[/cyan] to refresh your config template."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Gateway / Server
|
# Gateway / Server
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -310,6 +327,7 @@ def gateway(
|
|||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
config = _load_runtime_config(config, workspace)
|
config = _load_runtime_config(config, workspace)
|
||||||
|
_print_deprecated_memory_window_notice(config)
|
||||||
port = port if port is not None else config.gateway.port
|
port = port if port is not None else config.gateway.port
|
||||||
|
|
||||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||||
@@ -328,11 +346,8 @@ def gateway(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
temperature=config.agents.defaults.temperature,
|
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@@ -494,6 +509,7 @@ def agent(
|
|||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
config = _load_runtime_config(config, workspace)
|
config = _load_runtime_config(config, workspace)
|
||||||
|
_print_deprecated_memory_window_notice(config)
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
@@ -513,11 +529,8 @@ def agent(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
temperature=config.agents.defaults.temperature,
|
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@@ -670,6 +683,7 @@ app.add_typer(channels_app, name="channels")
|
|||||||
@channels_app.command("status")
|
@channels_app.command("status")
|
||||||
def channels_status():
|
def channels_status():
|
||||||
"""Show channel status."""
|
"""Show channel status."""
|
||||||
|
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||||
from nanobot.config.loader import load_config
|
from nanobot.config.loader import load_config
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
@@ -677,85 +691,19 @@ def channels_status():
|
|||||||
table = Table(title="Channel Status")
|
table = Table(title="Channel Status")
|
||||||
table.add_column("Channel", style="cyan")
|
table.add_column("Channel", style="cyan")
|
||||||
table.add_column("Enabled", style="green")
|
table.add_column("Enabled", style="green")
|
||||||
table.add_column("Configuration", style="yellow")
|
|
||||||
|
|
||||||
# WhatsApp
|
for modname in sorted(discover_channel_names()):
|
||||||
wa = config.channels.whatsapp
|
section = getattr(config.channels, modname, None)
|
||||||
table.add_row(
|
enabled = section and getattr(section, "enabled", False)
|
||||||
"WhatsApp",
|
try:
|
||||||
"✓" if wa.enabled else "✗",
|
cls = load_channel_class(modname)
|
||||||
wa.bridge_url
|
display = cls.display_name
|
||||||
)
|
except ImportError:
|
||||||
|
display = modname.title()
|
||||||
dc = config.channels.discord
|
table.add_row(
|
||||||
table.add_row(
|
display,
|
||||||
"Discord",
|
"[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
|
||||||
"✓" if dc.enabled else "✗",
|
)
|
||||||
dc.gateway_url
|
|
||||||
)
|
|
||||||
|
|
||||||
# Feishu
|
|
||||||
fs = config.channels.feishu
|
|
||||||
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"Feishu",
|
|
||||||
"✓" if fs.enabled else "✗",
|
|
||||||
fs_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mochat
|
|
||||||
mc = config.channels.mochat
|
|
||||||
mc_base = mc.base_url or "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"Mochat",
|
|
||||||
"✓" if mc.enabled else "✗",
|
|
||||||
mc_base
|
|
||||||
)
|
|
||||||
|
|
||||||
# Telegram
|
|
||||||
tg = config.channels.telegram
|
|
||||||
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"Telegram",
|
|
||||||
"✓" if tg.enabled else "✗",
|
|
||||||
tg_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Slack
|
|
||||||
slack = config.channels.slack
|
|
||||||
slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"Slack",
|
|
||||||
"✓" if slack.enabled else "✗",
|
|
||||||
slack_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# DingTalk
|
|
||||||
dt = config.channels.dingtalk
|
|
||||||
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"DingTalk",
|
|
||||||
"✓" if dt.enabled else "✗",
|
|
||||||
dt_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# QQ
|
|
||||||
qq = config.channels.qq
|
|
||||||
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"QQ",
|
|
||||||
"✓" if qq.enabled else "✗",
|
|
||||||
qq_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Email
|
|
||||||
em = config.channels.email
|
|
||||||
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
|
|
||||||
table.add_row(
|
|
||||||
"Email",
|
|
||||||
"✓" if em.enabled else "✗",
|
|
||||||
em_config
|
|
||||||
)
|
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|||||||
@@ -200,6 +200,14 @@ class QQConfig(Base):
|
|||||||
) # Allowed user openids (empty = public access)
|
) # Allowed user openids (empty = public access)
|
||||||
|
|
||||||
|
|
||||||
|
class WecomConfig(Base):
|
||||||
|
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
bot_id: str = "" # Bot ID from WeCom AI Bot platform
|
||||||
|
secret: str = "" # Bot Secret from WeCom AI Bot platform
|
||||||
|
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
|
||||||
|
welcome_message: str = "" # Welcome message for enter_chat event
|
||||||
|
|
||||||
|
|
||||||
class ChannelsConfig(Base):
|
class ChannelsConfig(Base):
|
||||||
@@ -217,6 +225,7 @@ class ChannelsConfig(Base):
|
|||||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||||
qq: QQConfig = Field(default_factory=QQConfig)
|
qq: QQConfig = Field(default_factory=QQConfig)
|
||||||
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
||||||
|
wecom: WecomConfig = Field(default_factory=WecomConfig)
|
||||||
|
|
||||||
|
|
||||||
class AgentDefaults(Base):
|
class AgentDefaults(Base):
|
||||||
@@ -228,11 +237,18 @@ class AgentDefaults(Base):
|
|||||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||||
)
|
)
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
|
context_window_tokens: int = 65_536
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 40
|
||||||
memory_window: int = 100
|
# Deprecated compatibility field: accepted from old configs but ignored at runtime.
|
||||||
|
memory_window: int | None = Field(default=None, exclude=True)
|
||||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_warn_deprecated_memory_window(self) -> bool:
|
||||||
|
"""Return True when old memoryWindow is present without contextWindowTokens."""
|
||||||
|
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
|
||||||
|
|
||||||
|
|
||||||
class AgentsConfig(Base):
|
class AgentsConfig(Base):
|
||||||
"""Agent configuration."""
|
"""Agent configuration."""
|
||||||
@@ -265,6 +281,7 @@ class ProvidersConfig(Base):
|
|||||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
|
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
||||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||||
@@ -368,16 +385,25 @@ class Config(BaseSettings):
|
|||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and model_prefix and normalized_prefix == spec.name:
|
if p and model_prefix and normalized_prefix == spec.name:
|
||||||
if spec.is_oauth or p.api_key:
|
if spec.is_oauth or spec.is_local or p.api_key:
|
||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
|
||||||
# Match by keyword (order follows PROVIDERS registry)
|
# Match by keyword (order follows PROVIDERS registry)
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and any(_kw_matches(kw) for kw in spec.keywords):
|
if p and any(_kw_matches(kw) for kw in spec.keywords):
|
||||||
if spec.is_oauth or p.api_key:
|
if spec.is_oauth or spec.is_local or p.api_key:
|
||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
|
||||||
|
# Fallback: configured local providers can route models without
|
||||||
|
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
||||||
|
for spec in PROVIDERS:
|
||||||
|
if not spec.is_local:
|
||||||
|
continue
|
||||||
|
p = getattr(self.providers, spec.name, None)
|
||||||
|
if p and p.api_base:
|
||||||
|
return p, spec.name
|
||||||
|
|
||||||
# Fallback: gateways first, then others (follows registry order)
|
# Fallback: gateways first, then others (follows registry order)
|
||||||
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
@@ -404,7 +430,7 @@ class Config(BaseSettings):
|
|||||||
return p.api_key if p else None
|
return p.api_key if p else None
|
||||||
|
|
||||||
def get_api_base(self, model: str | None = None) -> str | None:
|
def get_api_base(self, model: str | None = None) -> str | None:
|
||||||
"""Get API base URL for the given model. Applies default URLs for known gateways."""
|
"""Get API base URL for the given model. Applies default URLs for gateway/local providers."""
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
p, name = self._match_provider(model)
|
p, name = self._match_provider(model)
|
||||||
@@ -415,7 +441,7 @@ class Config(BaseSettings):
|
|||||||
# to avoid polluting the global litellm.api_base.
|
# to avoid polluting the global litellm.api_base.
|
||||||
if name:
|
if name:
|
||||||
spec = find_by_name(name)
|
spec = find_by_name(name)
|
||||||
if spec and spec.is_gateway and spec.default_api_base:
|
if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
|
||||||
return spec.default_api_base
|
return spec.default_api_base
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,21 @@ class LLMResponse:
|
|||||||
return len(self.tool_calls) > 0
|
return len(self.tool_calls) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GenerationSettings:
|
||||||
|
"""Default generation parameters for LLM calls.
|
||||||
|
|
||||||
|
Stored on the provider so every call site inherits the same defaults
|
||||||
|
without having to pass temperature / max_tokens / reasoning_effort
|
||||||
|
through every layer. Individual call sites can still override by
|
||||||
|
passing explicit keyword arguments to chat() / chat_with_retry().
|
||||||
|
"""
|
||||||
|
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = 4096
|
||||||
|
reasoning_effort: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMProvider(ABC):
|
class LLMProvider(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for LLM providers.
|
Abstract base class for LLM providers.
|
||||||
@@ -75,9 +90,12 @@ class LLMProvider(ABC):
|
|||||||
"temporarily unavailable",
|
"temporarily unavailable",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
|
self.generation: GenerationSettings = GenerationSettings()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
@@ -174,11 +192,23 @@ class LLMProvider(ABC):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: object = _SENTINEL,
|
||||||
temperature: float = 0.7,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: object = _SENTINEL,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures."""
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
|
Parameters default to ``self.generation`` when not explicitly passed,
|
||||||
|
so callers no longer need to thread temperature / max_tokens /
|
||||||
|
reasoning_effort through every layer.
|
||||||
|
"""
|
||||||
|
if max_tokens is self._SENTINEL:
|
||||||
|
max_tokens = self.generation.max_tokens
|
||||||
|
if temperature is self._SENTINEL:
|
||||||
|
temperature = self.generation.temperature
|
||||||
|
if reasoning_effort is self._SENTINEL:
|
||||||
|
reasoning_effort = self.generation.reasoning_effort
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||||
try:
|
try:
|
||||||
response = await self.chat(
|
response = await self.chat(
|
||||||
|
|||||||
@@ -360,6 +360,23 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
# === Ollama (local, OpenAI-compatible) ===================================
|
||||||
|
ProviderSpec(
|
||||||
|
name="ollama",
|
||||||
|
keywords=("ollama", "nemotron"),
|
||||||
|
env_key="OLLAMA_API_KEY",
|
||||||
|
display_name="Ollama",
|
||||||
|
litellm_prefix="ollama_chat", # model → ollama_chat/model
|
||||||
|
skip_prefixes=("ollama/", "ollama_chat/"),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=False,
|
||||||
|
is_local=True,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="11434",
|
||||||
|
default_api_base="http://localhost:11434",
|
||||||
|
strip_model_prefix=False,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
# === Auxiliary (not a primary LLM provider) ============================
|
# === Auxiliary (not a primary LLM provider) ============================
|
||||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
"""Utility functions for nanobot."""
|
"""Utility functions for nanobot."""
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
def detect_image_mime(data: bytes) -> str | None:
|
def detect_image_mime(data: bytes) -> str | None:
|
||||||
@@ -68,6 +72,104 @@ def split_message(content: str, max_len: int = 2000) -> list[str]:
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_assistant_message(
|
||||||
|
content: str | None,
|
||||||
|
tool_calls: list[dict[str, Any]] | None = None,
|
||||||
|
reasoning_content: str | None = None,
|
||||||
|
thinking_blocks: list[dict] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a provider-safe assistant message with optional reasoning fields."""
|
||||||
|
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||||
|
if tool_calls:
|
||||||
|
msg["tool_calls"] = tool_calls
|
||||||
|
if reasoning_content is not None:
|
||||||
|
msg["reasoning_content"] = reasoning_content
|
||||||
|
if thinking_blocks:
|
||||||
|
msg["thinking_blocks"] = thinking_blocks
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_prompt_tokens(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Estimate prompt tokens with tiktoken."""
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
parts: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
txt = part.get("text", "")
|
||||||
|
if txt:
|
||||||
|
parts.append(txt)
|
||||||
|
if tools:
|
||||||
|
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||||
|
return len(enc.encode("\n".join(parts)))
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_message_tokens(message: dict[str, Any]) -> int:
|
||||||
|
"""Estimate prompt tokens contributed by one persisted message."""
|
||||||
|
content = message.get("content")
|
||||||
|
parts: list[str] = []
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
text = part.get("text", "")
|
||||||
|
if text:
|
||||||
|
parts.append(text)
|
||||||
|
else:
|
||||||
|
parts.append(json.dumps(part, ensure_ascii=False))
|
||||||
|
elif content is not None:
|
||||||
|
parts.append(json.dumps(content, ensure_ascii=False))
|
||||||
|
|
||||||
|
for key in ("name", "tool_call_id"):
|
||||||
|
value = message.get(key)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
parts.append(value)
|
||||||
|
if message.get("tool_calls"):
|
||||||
|
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||||
|
|
||||||
|
payload = "\n".join(parts)
|
||||||
|
if not payload:
|
||||||
|
return 1
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return max(1, len(enc.encode(payload)))
|
||||||
|
except Exception:
|
||||||
|
return max(1, len(payload) // 4)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_prompt_tokens_chain(
|
||||||
|
provider: Any,
|
||||||
|
model: str | None,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> tuple[int, str]:
|
||||||
|
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
|
||||||
|
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
|
||||||
|
if callable(provider_counter):
|
||||||
|
try:
|
||||||
|
tokens, source = provider_counter(messages, tools, model)
|
||||||
|
if isinstance(tokens, (int, float)) and tokens > 0:
|
||||||
|
return int(tokens), str(source or "provider_counter")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
estimated = estimate_prompt_tokens(messages, tools)
|
||||||
|
if estimated > 0:
|
||||||
|
return int(estimated), "tiktoken"
|
||||||
|
return 0, "none"
|
||||||
|
|
||||||
|
|
||||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||||
from importlib.resources import files as pkg_files
|
from importlib.resources import files as pkg_files
|
||||||
@@ -88,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
|||||||
added.append(str(dest.relative_to(workspace)))
|
added.append(str(dest.relative_to(workspace)))
|
||||||
|
|
||||||
for item in tpl.iterdir():
|
for item in tpl.iterdir():
|
||||||
if item.name.endswith(".md"):
|
if item.name.endswith(".md") and not item.name.startswith("."):
|
||||||
_write(item, workspace / item.name)
|
_write(item, workspace / item.name)
|
||||||
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
||||||
_write(None, workspace / "memory" / "HISTORY.md")
|
_write(None, workspace / "memory" / "HISTORY.md")
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"typer>=0.20.0,<1.0.0",
|
"typer>=0.20.0,<1.0.0",
|
||||||
"litellm>=1.81.5,<2.0.0",
|
"litellm>=1.82.1,<2.0.0",
|
||||||
"pydantic>=2.12.0,<3.0.0",
|
"pydantic>=2.12.0,<3.0.0",
|
||||||
"pydantic-settings>=2.12.0,<3.0.0",
|
"pydantic-settings>=2.12.0,<3.0.0",
|
||||||
"websockets>=16.0,<17.0",
|
"websockets>=16.0,<17.0",
|
||||||
@@ -44,9 +44,13 @@ dependencies = [
|
|||||||
"json-repair>=0.57.0,<1.0.0",
|
"json-repair>=0.57.0,<1.0.0",
|
||||||
"chardet>=3.0.2,<6.0.0",
|
"chardet>=3.0.2,<6.0.0",
|
||||||
"openai>=2.8.0",
|
"openai>=2.8.0",
|
||||||
|
"tiktoken>=0.12.0,<1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
wecom = [
|
||||||
|
"wecom-aibot-sdk-python @ git+https://github.com/chengyongru/wecom_aibot_sdk.git@v0.1.2",
|
||||||
|
]
|
||||||
matrix = [
|
matrix = [
|
||||||
"matrix-nio[e2e]>=0.25.2",
|
"matrix-nio[e2e]>=0.25.2",
|
||||||
"mistune>=3.0.0,<4.0.0",
|
"mistune>=3.0.0,<4.0.0",
|
||||||
@@ -68,6 +72,9 @@ nanobot = "nanobot.cli.commands:app"
|
|||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["nanobot"]
|
packages = ["nanobot"]
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,35 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
|
|||||||
assert config.get_provider_name() == "openai_codex"
|
assert config.get_provider_name() == "openai_codex"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_matches_explicit_ollama_prefix_without_api_key():
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.model = "ollama/llama3.2"
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "ollama"
|
||||||
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.provider = "ollama"
|
||||||
|
config.agents.defaults.model = "llama3.2"
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "ollama"
|
||||||
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_auto_detects_ollama_from_local_api_base():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
|
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "ollama"
|
||||||
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||||
|
|
||||||
@@ -267,6 +296,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
|||||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
|
||||||
|
mock_agent_runtime["config"].agents.defaults.memory_window = 100
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "memoryWindow" in result.stdout
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||||
config_file = tmp_path / "instance" / "config.json"
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
config_file.parent.mkdir(parents=True)
|
config_file.parent.mkdir(parents=True)
|
||||||
@@ -327,6 +366,29 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
|||||||
assert seen["workspace"] == override
|
assert seen["workspace"] == override
|
||||||
assert config.workspace_path == override
|
assert config.workspace_path == override
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
|
config_file.parent.mkdir(parents=True)
|
||||||
|
config_file.write_text("{}")
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.memory_window = 100
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.cli.commands._make_provider",
|
||||||
|
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||||
|
|
||||||
|
assert isinstance(result.exception, _StopGateway)
|
||||||
|
assert "memoryWindow" in result.stdout
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
|
||||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||||
config_file = tmp_path / "instance" / "config.json"
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
config_file.parent.mkdir(parents=True)
|
config_file.parent.mkdir(parents=True)
|
||||||
|
|||||||
88
tests/test_config_migration.py
Normal file
88
tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from nanobot.cli.commands import app
|
||||||
|
from nanobot.config.loader import load_config, save_config
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 1234,
|
||||||
|
"memoryWindow": 42,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
|
||||||
|
assert config.agents.defaults.max_tokens == 1234
|
||||||
|
assert config.agents.defaults.context_window_tokens == 65_536
|
||||||
|
assert config.agents.defaults.should_warn_deprecated_memory_window is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 2222,
|
||||||
|
"memoryWindow": 30,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
save_config(config, config_path)
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
defaults = saved["agents"]["defaults"]
|
||||||
|
|
||||||
|
assert defaults["maxTokens"] == 2222
|
||||||
|
assert defaults["contextWindowTokens"] == 65_536
|
||||||
|
assert "memoryWindow" not in defaults
|
||||||
|
|
||||||
|
|
||||||
|
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 3333,
|
||||||
|
"memoryWindow": 50,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
defaults = saved["agents"]["defaults"]
|
||||||
|
assert defaults["maxTokens"] == 3333
|
||||||
|
assert defaults["contextWindowTokens"] == 65_536
|
||||||
|
assert "memoryWindow" not in defaults
|
||||||
@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
|
|||||||
assert_messages_content(old_messages, 10, 34)
|
assert_messages_content(old_messages, 10, 34)
|
||||||
|
|
||||||
|
|
||||||
class TestConsolidationDeduplicationGuard:
|
class TestNewCommandArchival:
|
||||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@staticmethod
|
||||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
def _make_loop(tmp_path: Path):
|
||||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
context_window_tokens=1,
|
||||||
)
|
)
|
||||||
|
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
return loop
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
consolidation_calls = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
nonlocal consolidation_calls
|
|
||||||
consolidation_calls += 1
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert consolidation_calls == 1, (
|
|
||||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
consolidation_calls = 0
|
|
||||||
active = 0
|
|
||||||
max_active = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
nonlocal consolidation_calls, active, max_active
|
|
||||||
consolidation_calls += 1
|
|
||||||
active += 1
|
|
||||||
max_active = max(max_active, active)
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
active -= 1
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
|
||||||
await loop._process_message(new_msg)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert consolidation_calls == 2, (
|
|
||||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
|
||||||
)
|
|
||||||
assert max_active == 1, (
|
|
||||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
|
||||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
|
|
||||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
started.set()
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
|
|
||||||
await started.wait()
|
|
||||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
|
||||||
|
|
||||||
await asyncio.sleep(0.15)
|
|
||||||
assert len(loop._consolidation_tasks) == 0, (
|
|
||||||
"Task reference must be removed after completion"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new waits for in-flight consolidation and archives before clear."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
release = asyncio.Event()
|
|
||||||
archived_count = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
|
||||||
nonlocal archived_count
|
|
||||||
if archive_all:
|
|
||||||
archived_count = len(sess.messages)
|
|
||||||
return True
|
|
||||||
started.set()
|
|
||||||
await release.wait()
|
|
||||||
return True
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await started.wait()
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
|
||||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
|
||||||
|
|
||||||
await asyncio.sleep(0.02)
|
|
||||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
|
||||||
|
|
||||||
release.set()
|
|
||||||
response = await pending_new
|
|
||||||
assert response is not None
|
|
||||||
assert "new session started" in response.content.lower()
|
|
||||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
|
||||||
|
|
||||||
session_after = loop.sessions.get_or_create("cli:test")
|
|
||||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||||
"""/new must keep session data if archive step reports failure."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
before_count = len(session.messages)
|
before_count = len(session.messages)
|
||||||
|
|
||||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _failing_consolidate(_messages) -> bool:
|
||||||
if archive_all:
|
return False
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
response = await loop._process_message(new_msg)
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "failed" in response.content.lower()
|
assert "failed" in response.content.lower()
|
||||||
session_after = loop.sessions.get_or_create("cli:test")
|
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||||
assert len(session_after.messages) == before_count, (
|
|
||||||
"Session must remain intact when /new archival fails"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new should archive only messages not yet consolidated by prior task."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
session.last_consolidated = len(session.messages) - 3
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
release = asyncio.Event()
|
|
||||||
archived_count = -1
|
archived_count = -1
|
||||||
|
|
||||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _fake_consolidate(messages) -> bool:
|
||||||
nonlocal archived_count
|
nonlocal archived_count
|
||||||
if archive_all:
|
archived_count = len(messages)
|
||||||
archived_count = len(sess.messages)
|
|
||||||
return True
|
|
||||||
|
|
||||||
started.set()
|
|
||||||
await release.wait()
|
|
||||||
sess.last_consolidated = len(sess.messages) - 3
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await started.wait()
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
response = await loop._process_message(new_msg)
|
||||||
await asyncio.sleep(0.02)
|
|
||||||
assert not pending_new.done()
|
|
||||||
|
|
||||||
release.set()
|
|
||||||
response = await pending_new
|
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "new session started" in response.content.lower()
|
assert "new session started" in response.content.lower()
|
||||||
assert archived_count == 3, (
|
assert archived_count == 3
|
||||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||||
"""/new clears session and returns confirmation."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _ok_consolidate(_messages) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
response = await loop._process_message(new_msg)
|
response = await loop._process_message(new_msg)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.dingtalk import DingTalkChannel
|
import nanobot.channels.dingtalk as dingtalk_module
|
||||||
|
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||||
from nanobot.config.schema import DingTalkConfig
|
from nanobot.config.schema import DingTalkConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None:
|
|||||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||||
assert call["json"]["openConversationId"] == "conv123"
|
assert call["json"]["openConversationId"] == "conv123"
|
||||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = DingTalkChannel(
|
||||||
|
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
handler = NanobotDingTalkHandler(channel)
|
||||||
|
|
||||||
|
class _FakeChatbotMessage:
|
||||||
|
text = None
|
||||||
|
extensions = {"content": {"recognition": "voice transcript"}}
|
||||||
|
sender_staff_id = "user1"
|
||||||
|
sender_id = "fallback-user"
|
||||||
|
sender_nick = "Alice"
|
||||||
|
message_type = "audio"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(_data):
|
||||||
|
return _FakeChatbotMessage()
|
||||||
|
|
||||||
|
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||||
|
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||||
|
|
||||||
|
status, body = await handler.process(
|
||||||
|
SimpleNamespace(
|
||||||
|
data={
|
||||||
|
"conversationType": "2",
|
||||||
|
"conversationId": "conv123",
|
||||||
|
"text": {"content": ""},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.gather(*list(channel._background_tasks))
|
||||||
|
msg = await bus.consume_inbound()
|
||||||
|
|
||||||
|
assert (status, body) == ("OK", "OK")
|
||||||
|
assert msg.content == "voice transcript"
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
assert msg.chat_id == "group:conv123"
|
||||||
|
|||||||
190
tests/test_loop_consolidation_tokens.py
Normal file
190
tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
import nanobot.agent.memory as memory_module
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||||
|
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||||
|
assert session.last_consolidated == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||||
|
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
return (500, "test")
|
||||||
|
if call_count[0] == 2:
|
||||||
|
return (300, "test")
|
||||||
|
return (80, "test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||||
|
assert session.last_consolidated == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||||
|
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
return (500, "test")
|
||||||
|
if call_count[0] == 2:
|
||||||
|
return (150, "test")
|
||||||
|
return (80, "test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||||
|
assert session.last_consolidated == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||||
|
order: list[str] = []
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
|
||||||
|
async def track_consolidate(messages):
|
||||||
|
order.append("consolidate")
|
||||||
|
return True
|
||||||
|
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
async def track_llm(*args, **kwargs):
|
||||||
|
order.append("llm")
|
||||||
|
return LLMResponse(content="ok", tool_calls=[])
|
||||||
|
loop.provider.chat_with_retry = track_llm
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
assert "consolidate" in order
|
||||||
|
assert "llm" in order
|
||||||
|
assert order.index("consolidate") < order.index("llm")
|
||||||
@@ -7,7 +7,7 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -15,15 +15,12 @@ from nanobot.agent.memory import MemoryStore
|
|||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
def _make_messages(message_count: int = 30):
|
||||||
"""Create a mock session with messages."""
|
"""Create a list of mock messages."""
|
||||||
session = MagicMock()
|
return [
|
||||||
session.messages = [
|
|
||||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||||
for i in range(message_count)
|
for i in range(message_count)
|
||||||
]
|
]
|
||||||
session.last_consolidated = 0
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tool_response(history_entry, memory_update):
|
def _make_tool_response(history_entry, memory_update):
|
||||||
@@ -74,9 +71,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert store.history_file.exists()
|
assert store.history_file.exists()
|
||||||
@@ -95,9 +92,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert store.history_file.exists()
|
assert store.history_file.exists()
|
||||||
@@ -131,9 +128,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert "User discussed testing." in store.history_file.read_text()
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
@@ -147,22 +144,22 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||||
)
|
)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
assert not store.history_file.exists()
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||||
"""Consolidation should be a no-op when messages < keep_count."""
|
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=10)
|
messages: list[dict] = []
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
provider.chat.assert_not_called()
|
provider.chat.assert_not_called()
|
||||||
@@ -189,9 +186,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert "User discussed testing." in store.history_file.read_text()
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
@@ -215,9 +212,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@@ -239,9 +236,9 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
provider.chat_with_retry = provider.chat
|
provider.chat_with_retry = provider.chat
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@@ -255,7 +252,7 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
memory_update="# Memory\nUser likes testing.",
|
memory_update="# Memory\nUser likes testing.",
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
session = _make_session(message_count=60)
|
messages = _make_messages(message_count=60)
|
||||||
delays: list[int] = []
|
delays: list[int] = []
|
||||||
|
|
||||||
async def _fake_sleep(delay: int) -> None:
|
async def _fake_sleep(delay: int) -> None:
|
||||||
@@ -263,8 +260,31 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
|
|
||||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert provider.calls == 2
|
assert provider.calls == 2
|
||||||
assert delays == [1]
|
assert delays == [1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
|
||||||
|
"""Consolidation no longer passes generation params — the provider owns them."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry="[2026-01-01] User discussed testing.",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
provider.chat_with_retry.assert_awaited_once()
|
||||||
|
_, kwargs = provider.chat_with_retry.await_args
|
||||||
|
assert kwargs["model"] == "test-model"
|
||||||
|
assert "temperature" not in kwargs
|
||||||
|
assert "max_tokens" not in kwargs
|
||||||
|
assert "reasoning_effort" not in kwargs
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
|
||||||
|
|
||||||
class TestMessageToolSuppressLogic:
|
class TestMessageToolSuppressLogic:
|
||||||
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
LLMResponse(content="", tool_calls=[tool_call]),
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
LLMResponse(content="Done", tool_calls=[]),
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
LLMResponse(content="", tool_calls=[tool_call]),
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||||
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
),
|
),
|
||||||
LLMResponse(content="Done", tool_calls=[]),
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
loop.tools.execute = AsyncMock(return_value="ok")
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
class ScriptedProvider(LLMProvider):
|
class ScriptedProvider(LLMProvider):
|
||||||
@@ -10,9 +10,11 @@ class ScriptedProvider(LLMProvider):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._responses = list(responses)
|
self._responses = list(responses)
|
||||||
self.calls = 0
|
self.calls = 0
|
||||||
|
self.last_kwargs: dict = {}
|
||||||
|
|
||||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
self.calls += 1
|
self.calls += 1
|
||||||
|
self.last_kwargs = kwargs
|
||||||
response = self._responses.pop(0)
|
response = self._responses.pop(0)
|
||||||
if isinstance(response, BaseException):
|
if isinstance(response, BaseException):
|
||||||
raise response
|
raise response
|
||||||
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
|||||||
|
|
||||||
with pytest.raises(asyncio.CancelledError):
|
with pytest.raises(asyncio.CancelledError):
|
||||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||||
|
"""When callers omit generation params, provider.generation defaults are used."""
|
||||||
|
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||||
|
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||||
|
|
||||||
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
assert provider.last_kwargs["temperature"] == 0.2
|
||||||
|
assert provider.last_kwargs["max_tokens"] == 321
|
||||||
|
assert provider.last_kwargs["reasoning_effort"] == "high"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||||
|
"""Explicit kwargs should override provider.generation defaults."""
|
||||||
|
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||||
|
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||||
|
|
||||||
|
await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
temperature=0.9,
|
||||||
|
max_tokens=9999,
|
||||||
|
reasoning_effort="low",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.last_kwargs["temperature"] == 0.9
|
||||||
|
assert provider.last_kwargs["max_tokens"] == 9999
|
||||||
|
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||||
|
|||||||
@@ -165,3 +165,46 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
|
||||||
|
captured_second_call: list[dict] = []
|
||||||
|
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def scripted_chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="thinking",
|
||||||
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||||
|
reasoning_content="hidden reasoning",
|
||||||
|
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||||
|
)
|
||||||
|
captured_second_call[:] = messages
|
||||||
|
return LLMResponse(content="done", tool_calls=[])
|
||||||
|
provider.chat_with_retry = scripted_chat_with_retry
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||||
|
|
||||||
|
async def fake_execute(self, name, arguments):
|
||||||
|
return "tool result"
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||||
|
|
||||||
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
|
assistant_messages = [
|
||||||
|
msg for msg in captured_second_call
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||||
|
]
|
||||||
|
assert len(assistant_messages) == 1
|
||||||
|
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||||
|
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||||
|
|||||||
Reference in New Issue
Block a user