Merge remote-tracking branch 'origin/main' into pr-949

Changed file: README.md (72 lines changed)
@@ -16,7 +16,7 @@
 
 ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
 
-📏 Real-time line count: **3,806 lines** (run `bash core_agent_lines.sh` to verify anytime)
+📏 Real-time line count: **3,862 lines** (run `bash core_agent_lines.sh` to verify anytime)
 
 ## 📢 News
 
@@ -34,6 +34,7 @@
 
 <details>
 <summary>Earlier news</summary>
 
 - **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
 - **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
 - **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
@@ -43,6 +44,7 @@
 - **2026-02-04** 🚀 Released **v0.1.3.post4** with multi-provider & Docker support! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post4) for details.
 - **2026-02-03** ⚡ Integrated vLLM for local LLM support and improved natural language task scheduling!
 - **2026-02-02** 🎉 nanobot officially launched! Welcome to try 🐈 nanobot!
 
 </details>
 
 ## Key Features of nanobot:
@@ -774,6 +776,21 @@ Two transport modes are supported:
 | **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
 | **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) |
 
+Use `toolTimeout` to override the default 30s per-call timeout for slow servers:
+
+```json
+{
+  "tools": {
+    "mcpServers": {
+      "my-slow-server": {
+        "url": "https://example.com/mcp/",
+        "toolTimeout": 120
+      }
+    }
+  }
+}
+```
+
 MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
 
 
@@ -863,6 +880,59 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
 docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
 ```
 
+## 🐧 Linux Service
+
+Run the gateway as a systemd user service so it starts automatically and restarts on failure.
+
+**1. Find the nanobot binary path:**
+
+```bash
+which nanobot  # e.g. /home/user/.local/bin/nanobot
+```
+
+**2. Create the service file** at `~/.config/systemd/user/nanobot-gateway.service` (replace `ExecStart` path if needed):
+
+```ini
+[Unit]
+Description=Nanobot Gateway
+After=network.target
+
+[Service]
+Type=simple
+ExecStart=%h/.local/bin/nanobot gateway
+Restart=always
+RestartSec=10
+NoNewPrivileges=yes
+ProtectSystem=strict
+ReadWritePaths=%h
+
+[Install]
+WantedBy=default.target
+```
+
+**3. Enable and start:**
+
+```bash
+systemctl --user daemon-reload
+systemctl --user enable --now nanobot-gateway
+```
+
+**Common operations:**
+
+```bash
+systemctl --user status nanobot-gateway   # check status
+systemctl --user restart nanobot-gateway  # restart after config changes
+journalctl --user -u nanobot-gateway -f   # follow logs
+```
+
+If you edit the `.service` file itself, run `systemctl --user daemon-reload` before restarting.
+
+> **Note:** User services only run while you are logged in. To keep the gateway running after logout, enable lingering:
+>
+> ```bash
+> loginctl enable-linger $USER
+> ```
+
 ## 📁 Project Structure
 
 ```
@@ -82,12 +82,7 @@ Skills with available="false" need dependencies installed first - you can try in
 
     return f"""# nanobot 🐈
 
-You are nanobot, a helpful AI assistant. You have access to tools that allow you to:
-- Read, write, and edit files
-- Execute shell commands
-- Search the web and fetch web pages
-- Send messages to users on chat channels
-- Spawn subagents for complex background tasks
+You are nanobot, a helpful AI assistant.
 
 ## Current Time
 {now} ({tz})
@@ -236,7 +231,7 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
             msg["tool_calls"] = tool_calls
 
         # Include reasoning content when provided (required by some thinking models)
-        if reasoning_content:
+        if reasoning_content is not None:
             msg["reasoning_content"] = reasoning_content
 
         messages.append(msg)
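The guard change above is not cosmetic: `if reasoning_content:` also drops an empty string, while `is not None` only drops a true absence. A standalone illustration (the `build_msg` helper is hypothetical, not the project's function):

```python
def build_msg(reasoning_content):
    # Mirrors the changed guard: keep the field whenever the provider
    # returned a value at all, including an empty string.
    msg = {"role": "assistant"}
    if reasoning_content is not None:
        msg["reasoning_content"] = reasoning_content
    return msg

# A truthiness check would silently drop ""; the identity check keeps it.
assert "reasoning_content" not in build_msg(None)
assert build_msg("")["reasoning_content"] == ""
assert build_msg("thinking...")["reasoning_content"] == "thinking..."
```

Note the complementary change in the provider hunks below-the-fold (`getattr(..., None) or None`), which normalizes empty provider output to `None` before it ever reaches this guard.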
@@ -7,7 +7,7 @@ import json
 import re
 from contextlib import AsyncExitStack
 from pathlib import Path
-from typing import TYPE_CHECKING, Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Awaitable, Callable
 
 from loguru import logger
 
@@ -95,6 +95,8 @@ class AgentLoop:
         self._mcp_connected = False
         self._mcp_connecting = False
         self._consolidating: set[str] = set()  # Session keys with consolidation in progress
+        self._consolidation_tasks: set[asyncio.Task] = set()  # Strong refs to in-flight tasks
+        self._consolidation_locks: dict[str, asyncio.Lock] = {}
         self._register_default_tools()
 
     def _register_default_tools(self) -> None:
@@ -194,7 +196,8 @@ class AgentLoop:
                 clean = self._strip_think(response.content)
                 if clean:
                     await on_progress(clean)
-                await on_progress(self._tool_hint(response.tool_calls))
+                else:
+                    await on_progress(self._tool_hint(response.tool_calls))
 
                 tool_call_dicts = [
                     {
@@ -270,6 +273,18 @@ class AgentLoop:
         self._running = False
         logger.info("Agent loop stopping")
 
+    def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
+        lock = self._consolidation_locks.get(session_key)
+        if lock is None:
+            lock = asyncio.Lock()
+            self._consolidation_locks[session_key] = lock
+        return lock
+
+    def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
+        """Drop lock entry if no longer in use."""
+        if not lock.locked():
+            self._consolidation_locks.pop(session_key, None)
+
     async def _process_message(
         self,
         msg: InboundMessage,
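The `_get_consolidation_lock` / `_prune_consolidation_lock` pair added above is a per-key lock registry: a lock is created on demand and forgotten once nothing holds it, so the dict does not grow with dead session keys. A self-contained sketch of the same pattern (`LockRegistry` and its names are illustrative, not the project's API):

```python
import asyncio

class LockRegistry:
    """Per-key asyncio locks, forgotten once no coroutine holds them."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}

    def get(self, key: str) -> asyncio.Lock:
        lock = self._locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            self._locks[key] = lock
        return lock

    def prune(self, key: str, lock: asyncio.Lock) -> None:
        # Only drop the registry entry if the lock is not currently held.
        if not lock.locked():
            self._locks.pop(key, None)

reg = LockRegistry()
order: list[int] = []

async def worker(n: int) -> None:
    lock = reg.get("session-1")
    try:
        async with lock:
            order.append(n)
            await asyncio.sleep(0)  # yield control while holding the lock
            order.append(n)
    finally:
        reg.prune("session-1", lock)

async def main() -> None:
    await asyncio.gather(worker(1), worker(2))

asyncio.run(main())

# Critical sections never interleave, and the registry ends up empty.
assert order == [1, 1, 2, 2]
assert reg._locks == {}
```

Pruning in `finally` is safe because a waiter that already called `get()` keeps its own reference to the lock object even after the dict entry is dropped.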
@@ -305,33 +320,55 @@ class AgentLoop:
         # Slash commands
         cmd = msg.content.strip().lower()
         if cmd == "/new":
-            messages_to_archive = session.messages.copy()
+            lock = self._get_consolidation_lock(session.key)
+            self._consolidating.add(session.key)
+            try:
+                async with lock:
+                    snapshot = session.messages[session.last_consolidated:]
+                    if snapshot:
+                        temp = Session(key=session.key)
+                        temp.messages = list(snapshot)
+                        if not await self._consolidate_memory(temp, archive_all=True):
+                            return OutboundMessage(
+                                channel=msg.channel, chat_id=msg.chat_id,
+                                content="Memory archival failed, session not cleared. Please try again.",
+                            )
+            except Exception:
+                logger.exception("/new archival failed for {}", session.key)
+                return OutboundMessage(
+                    channel=msg.channel, chat_id=msg.chat_id,
+                    content="Memory archival failed, session not cleared. Please try again.",
+                )
+            finally:
+                self._consolidating.discard(session.key)
+                self._prune_consolidation_lock(session.key, lock)
+
             session.clear()
             self.sessions.save(session)
             self.sessions.invalidate(session.key)
 
-            async def _consolidate_and_cleanup():
-                temp = Session(key=session.key)
-                temp.messages = messages_to_archive
-                await self._consolidate_memory(temp, archive_all=True)
-
-            asyncio.create_task(_consolidate_and_cleanup())
             return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
-                                   content="New session started. Memory consolidation in progress.")
+                                   content="New session started.")
         if cmd == "/help":
             return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                    content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
 
         if len(session.messages) > self.memory_window and session.key not in self._consolidating:
             self._consolidating.add(session.key)
+            lock = self._get_consolidation_lock(session.key)
 
             async def _consolidate_and_unlock():
                 try:
-                    await self._consolidate_memory(session)
+                    async with lock:
+                        await self._consolidate_memory(session)
                 finally:
                     self._consolidating.discard(session.key)
+                    self._prune_consolidation_lock(session.key, lock)
+                    _task = asyncio.current_task()
+                    if _task is not None:
+                        self._consolidation_tasks.discard(_task)
 
-            asyncio.create_task(_consolidate_and_unlock())
+            _task = asyncio.create_task(_consolidate_and_unlock())
+            self._consolidation_tasks.add(_task)
 
         self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
         if message_tool := self.tools.get("message"):
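Storing the task in `_consolidation_tasks` matters because the event loop holds only a weak reference to tasks created with `asyncio.create_task`; an unreferenced task can be garbage-collected before it finishes. The add-on-create / discard-in-finally pattern, sketched standalone (names are illustrative):

```python
import asyncio

background_tasks: set[asyncio.Task] = set()
done: list[str] = []

async def consolidate(key: str) -> None:
    try:
        await asyncio.sleep(0)  # stand-in for the real consolidation work
        done.append(key)
    finally:
        # The task removes itself from the strong-reference set on completion.
        task = asyncio.current_task()
        if task is not None:
            background_tasks.discard(task)

async def main() -> None:
    # Keep a strong reference so the task cannot be garbage-collected mid-flight.
    task = asyncio.create_task(consolidate("session-1"))
    background_tasks.add(task)
    await task

asyncio.run(main())

assert done == ["session-1"]
assert background_tasks == set()
```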
@@ -376,9 +413,9 @@ class AgentLoop:
             metadata=msg.metadata or {},
         )
 
-    async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
-        """Delegate to MemoryStore.consolidate()."""
-        await MemoryStore(self.workspace).consolidate(
+    async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
+        """Delegate to MemoryStore.consolidate(). Returns True on success."""
+        return await MemoryStore(self.workspace).consolidate(
             session, self.provider, self.model,
             archive_all=archive_all, memory_window=self.memory_window,
         )
@@ -74,8 +74,11 @@ class MemoryStore:
         *,
         archive_all: bool = False,
         memory_window: int = 50,
-    ) -> None:
-        """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call."""
+    ) -> bool:
+        """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
+
+        Returns True on success (including no-op), False on failure.
+        """
         if archive_all:
             old_messages = session.messages
             keep_count = 0
@@ -83,12 +86,12 @@
         else:
             keep_count = memory_window // 2
             if len(session.messages) <= keep_count:
-                return
+                return True
             if len(session.messages) - session.last_consolidated <= 0:
-                return
+                return True
             old_messages = session.messages[session.last_consolidated:-keep_count]
             if not old_messages:
-                return
+                return True
 
         logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
 
         lines = []
@@ -119,7 +122,7 @@
 
         if not response.has_tool_calls:
             logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
-            return
+            return False
 
         args = response.tool_calls[0].arguments
         if entry := args.get("history_entry"):
@@ -134,5 +137,7 @@
 
             session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
             logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
-        except Exception as e:
-            logger.error("Memory consolidation failed: {}", e)
+            return True
+        except Exception:
+            logger.exception("Memory consolidation failed")
+            return False
@@ -13,8 +13,11 @@ def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path |
     if not p.is_absolute() and workspace:
         p = workspace / p
     resolved = p.resolve()
-    if allowed_dir and not str(resolved).startswith(str(allowed_dir.resolve())):
-        raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
+    if allowed_dir:
+        try:
+            resolved.relative_to(allowed_dir.resolve())
+        except ValueError:
+            raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
     return resolved
 
 
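The replaced `startswith` check is the classic prefix-confusion bug: a sibling path that merely shares a name prefix passes the string test, while `Path.relative_to` compares whole path components. A minimal standalone check of the same idea (the `is_inside` helper is illustrative, not the project's function):

```python
from pathlib import Path

def is_inside(candidate: str, allowed_dir: str) -> bool:
    """True iff candidate is allowed_dir or one of its descendants."""
    try:
        Path(candidate).resolve().relative_to(Path(allowed_dir).resolve())
        return True
    except ValueError:
        return False

# A sibling directory sharing a name prefix: the naive string test
# str(p).startswith("/srv/data") wrongly accepts it.
assert "/srv/data-secrets/key".startswith("/srv/data")      # the bug
assert not is_inside("/srv/data-secrets/key", "/srv/data")  # the fix
assert is_inside("/srv/data/report.txt", "/srv/data")
```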
@@ -1,5 +1,6 @@
 """MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
 
+import asyncio
 from contextlib import AsyncExitStack
 from typing import Any
 
@@ -13,12 +14,13 @@ from nanobot.agent.tools.registry import ToolRegistry
 class MCPToolWrapper(Tool):
     """Wraps a single MCP server tool as a nanobot Tool."""
 
-    def __init__(self, session, server_name: str, tool_def):
+    def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
         self._session = session
         self._original_name = tool_def.name
         self._name = f"mcp_{server_name}_{tool_def.name}"
         self._description = tool_def.description or tool_def.name
         self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
+        self._tool_timeout = tool_timeout
 
     @property
     def name(self) -> str:
@@ -34,7 +36,14 @@ class MCPToolWrapper(Tool):
 
     async def execute(self, **kwargs: Any) -> str:
         from mcp import types
-        result = await self._session.call_tool(self._original_name, arguments=kwargs)
+        try:
+            result = await asyncio.wait_for(
+                self._session.call_tool(self._original_name, arguments=kwargs),
+                timeout=self._tool_timeout,
+            )
+        except asyncio.TimeoutError:
+            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
+            return f"(MCP tool call timed out after {self._tool_timeout}s)"
         parts = []
         for block in result.content:
             if isinstance(block, types.TextContent):
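The `asyncio.wait_for` wrapper above turns a hung MCP server into a bounded failure returned as an ordinary tool-result string that the LLM can see. A minimal sketch of the same shape (`slow_call` is a stand-in, not the MCP client API):

```python
import asyncio

async def slow_call() -> str:
    await asyncio.sleep(10)  # simulates a hung MCP server
    return "real result"

async def call_with_timeout(timeout: float) -> str:
    try:
        return await asyncio.wait_for(slow_call(), timeout=timeout)
    except asyncio.TimeoutError:
        # wait_for cancels the inner coroutine before raising.
        return f"(MCP tool call timed out after {timeout}s)"

result = asyncio.run(call_with_timeout(0.01))
assert result == "(MCP tool call timed out after 0.01s)"
```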
@@ -83,7 +92,7 @@ async def connect_mcp_servers(
 
     tools = await session.list_tools()
     for tool_def in tools.tools:
-        wrapper = MCPToolWrapper(session, name, tool_def)
+        wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
         registry.register(wrapper)
         logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
 
@@ -304,7 +304,8 @@ class EmailChannel(BaseChannel):
                 self._processed_uids.add(uid)
                 # mark_seen is the primary dedup; this set is a safety net
                 if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
-                    self._processed_uids.clear()
+                    # Evict a random half to cap memory; mark_seen is the primary dedup
+                    self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
 
                 if mark_seen:
                     client.store(imap_id, "+FLAGS", "\\Seen")
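Dropping half the set instead of clearing it keeps some dedup coverage while still capping memory. Since Python set iteration order is arbitrary, the evicted half is arbitrary too, which is acceptable here because the IMAP `\Seen` flag is the primary dedup. Standalone sketch with illustrative values:

```python
MAX_PROCESSED = 6

processed: set[str] = {f"uid-{i}" for i in range(8)}

if len(processed) > MAX_PROCESSED:
    # Keep an arbitrary half; set order is unspecified, which is fine
    # for a best-effort safety net on top of the IMAP \Seen flag.
    processed = set(list(processed)[len(processed) // 2:])

assert len(processed) == 4
assert processed <= {f"uid-{i}" for i in range(8)}
```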
@@ -503,18 +503,29 @@ class FeishuChannel(BaseChannel):
             logger.error("Error downloading image {}: {}", image_key, e)
             return None, None
 
-    def _download_file_sync(self, file_key: str) -> tuple[bytes | None, str | None]:
-        """Download a file from Feishu by file_key."""
+    def _download_file_sync(
+        self, message_id: str, file_key: str, resource_type: str = "file"
+    ) -> tuple[bytes | None, str | None]:
+        """Download a file/audio/media from a Feishu message by message_id and file_key."""
         try:
-            request = GetFileRequest.builder().file_key(file_key).build()
-            response = self._client.im.v1.file.get(request)
+            request = (
+                GetMessageResourceRequest.builder()
+                .message_id(message_id)
+                .file_key(file_key)
+                .type(resource_type)
+                .build()
+            )
+            response = self._client.im.v1.message_resource.get(request)
             if response.success():
-                return response.file, response.file_name
+                file_data = response.file
+                if hasattr(file_data, "read"):
+                    file_data = file_data.read()
+                return file_data, response.file_name
             else:
-                logger.error("Failed to download file: code={}, msg={}", response.code, response.msg)
+                logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
                 return None, None
-        except Exception as e:
-            logger.error("Error downloading file {}: {}", file_key, e)
+        except Exception:
+            logger.exception("Error downloading {} {}", resource_type, file_key)
             return None, None
 
     async def _download_and_save_media(
@@ -544,14 +555,14 @@ class FeishuChannel(BaseChannel):
             if not filename:
                 filename = f"{image_key[:16]}.jpg"
 
-        elif msg_type in ("audio", "file"):
+        elif msg_type in ("audio", "file", "media"):
             file_key = content_json.get("file_key")
-            if file_key:
+            if file_key and message_id:
                 data, filename = await loop.run_in_executor(
-                    None, self._download_file_sync, file_key
+                    None, self._download_file_sync, message_id, file_key, msg_type
                 )
                 if not filename:
-                    ext = ".opus" if msg_type == "audio" else ""
+                    ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
                     filename = f"{file_key[:16]}{ext}"
 
         if data and filename:
@@ -684,7 +695,7 @@ class FeishuChannel(BaseChannel):
         if text:
             content_parts.append(text)
 
-        elif msg_type in ("image", "audio", "file"):
+        elif msg_type in ("image", "audio", "file", "media"):
             file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
             if file_path:
                 media_paths.append(file_path)
@@ -55,7 +55,6 @@ class QQChannel(BaseChannel):
         self.config: QQConfig = config
         self._client: "botpy.Client | None" = None
         self._processed_ids: deque = deque(maxlen=1000)
-        self._bot_task: asyncio.Task | None = None
 
     async def start(self) -> None:
         """Start the QQ bot."""
@@ -71,8 +70,8 @@ class QQChannel(BaseChannel):
         BotClass = _make_bot_class(self)
         self._client = BotClass()
 
-        self._bot_task = asyncio.create_task(self._run_bot())
         logger.info("QQ bot started (C2C private message)")
+        await self._run_bot()
 
     async def _run_bot(self) -> None:
         """Run the bot connection with auto-reconnect."""
@@ -88,11 +87,10 @@ class QQChannel(BaseChannel):
     async def stop(self) -> None:
         """Stop the QQ bot."""
        self._running = False
-        if self._bot_task:
-            self._bot_task.cancel()
+        if self._client:
             try:
-                await self._bot_task
-            except asyncio.CancelledError:
+                await self._client.close()
+            except Exception:
                 pass
         logger.info("QQ bot stopped")
 
@@ -130,5 +128,5 @@ class QQChannel(BaseChannel):
                 content=content,
                 metadata={"message_id": data.id},
             )
-        except Exception as e:
-            logger.error("Error handling QQ message: {}", e)
+        except Exception:
+            logger.exception("Error handling QQ message")
@@ -179,18 +179,21 @@ class SlackChannel(BaseChannel):
         except Exception as e:
             logger.debug("Slack reactions_add failed: {}", e)
 
-        await self._handle_message(
-            sender_id=sender_id,
-            chat_id=chat_id,
-            content=text,
-            metadata={
-                "slack": {
-                    "event": event,
-                    "thread_ts": thread_ts,
-                    "channel_type": channel_type,
-                }
-            },
-        )
+        try:
+            await self._handle_message(
+                sender_id=sender_id,
+                chat_id=chat_id,
+                content=text,
+                metadata={
+                    "slack": {
+                        "event": event,
+                        "thread_ts": thread_ts,
+                        "channel_type": channel_type,
+                    }
+                },
+            )
+        except Exception:
+            logger.exception("Error handling Slack message from {}", sender_id)
 
     def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
         if channel_type == "im":
@@ -668,6 +668,33 @@ def channels_status():
         slack_config
     )
 
+    # DingTalk
+    dt = config.channels.dingtalk
+    dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
+    table.add_row(
+        "DingTalk",
+        "✓" if dt.enabled else "✗",
+        dt_config
+    )
+
+    # QQ
+    qq = config.channels.qq
+    qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
+    table.add_row(
+        "QQ",
+        "✓" if qq.enabled else "✗",
+        qq_config
+    )
+
+    # Email
+    em = config.channels.email
+    em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
+    table.add_row(
+        "Email",
+        "✓" if em.enabled else "✗",
+        em_config
+    )
+
     console.print(table)
 
 
@@ -260,6 +260,7 @@ class MCPServerConfig(Base):
     env: dict[str, str] = Field(default_factory=dict)  # Stdio: extra env vars
     url: str = ""  # HTTP: streamable HTTP endpoint URL
     headers: dict[str, str] = Field(default_factory=dict)  # HTTP: Custom HTTP Headers
+    tool_timeout: int = 30  # Seconds before a tool call is cancelled
 
 
 class ToolsConfig(Base):
@@ -44,7 +44,7 @@ class CustomProvider(LLMProvider):
         return LLMResponse(
             content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
             usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
-            reasoning_content=getattr(msg, "reasoning_content", None),
+            reasoning_content=getattr(msg, "reasoning_content", None) or None,
         )
 
     def get_default_model(self) -> str:
@@ -257,7 +257,7 @@ class LiteLLMProvider(LLMProvider):
                 "total_tokens": response.usage.total_tokens,
             }
 
-        reasoning_content = getattr(message, "reasoning_content", None)
+        reasoning_content = getattr(message, "reasoning_content", None) or None
 
         return LLMResponse(
             content=message.content,
@@ -147,7 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("volcengine", "volces", "ark"),
         env_key="OPENAI_API_KEY",
         display_name="VolcEngine",
-        litellm_prefix="openai",
+        litellm_prefix="volcengine",
         skip_prefixes=(),
         env_extras=(),
         is_gateway=True,
@@ -1,6 +1,7 @@
 """Session management for conversation history."""
 
 import json
+import shutil
 from pathlib import Path
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -108,9 +109,11 @@ class SessionManager:
         if not path.exists():
             legacy_path = self._get_legacy_session_path(key)
             if legacy_path.exists():
-                import shutil
-                shutil.move(str(legacy_path), str(path))
-                logger.info("Migrated session {} from legacy path", key)
+                try:
+                    shutil.move(str(legacy_path), str(path))
+                    logger.info("Migrated session {} from legacy path", key)
+                except Exception:
+                    logger.exception("Failed to migrate session {}", key)
 
         if not path.exists():
             return None
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Test session management with cache-friendly message handling."""
|
"""Test session management with cache-friendly message handling."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
@@ -475,3 +478,351 @@ class TestEmptyAndBoundarySessions:
|
|||||||
expected_count = 60 - KEEP_COUNT - 10
|
expected_count = 60 - KEEP_COUNT - 10
|
||||||
assert len(old_messages) == expected_count
|
assert len(old_messages) == expected_count
|
||||||
assert_messages_content(old_messages, 10, 34)
|
assert_messages_content(old_messages, 10, 34)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsolidationDeduplicationGuard:
|
||||||
|
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||||
|
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls
|
||||||
|
consolidation_calls += 1
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 1, (
|
||||||
|
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
active = 0
|
||||||
|
max_active = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls, active, max_active
|
||||||
|
consolidation_calls += 1
|
||||||
|
active += 1
|
||||||
|
max_active = max(max_active, active)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
active -= 1
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
await loop._process_message(new_msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 2, (
|
||||||
|
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
assert max_active == 1, (
|
||||||
|
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||||
|
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
await started.wait()
|
||||||
|
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||||
|
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
assert len(loop._consolidation_tasks) == 0, (
|
||||||
|
"Task reference must be removed after completion"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new waits for in-flight consolidation and archives before clear."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
archived_count = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
nonlocal archived_count
|
||||||
|
if archive_all:
|
||||||
|
archived_count = len(sess.messages)
|
||||||
|
return True
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||||
|
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
response = await pending_new
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||||
|
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||||
|
"""/new must keep session data if archive step reports failure."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
before_count = len(session.messages)
|
||||||
|
|
||||||
|
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
if archive_all:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "failed" in response.content.lower()
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert len(session_after.messages) == before_count, (
|
||||||
|
"Session must remain intact when /new archival fails"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new should archive only messages not yet consolidated by prior task."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
archived_count = -1
|
||||||
|
|
||||||
|
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
nonlocal archived_count
|
||||||
|
if archive_all:
|
||||||
|
archived_count = len(sess.messages)
|
||||||
|
return True
|
||||||
|
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
sess.last_consolidated = len(sess.messages) - 3
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
assert not pending_new.done()
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
response = await pending_new
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert archived_count == 3, (
|
||||||
|
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new should remove lock entry for fully invalidated session key."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(3):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
# Ensure lock exists before /new.
|
||||||
|
_ = loop._get_consolidation_lock(session.key)
|
||||||
|
assert session.key in loop._consolidation_locks
|
||||||
|
|
||||||
|
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert session.key not in loop._consolidation_locks
|
||||||
|
|||||||