Merge remote-tracking branch 'origin/main' into pr-986

This commit is contained in:
Re-bin
2026-02-22 18:13:37 +00:00
13 changed files with 535 additions and 56 deletions

View File

@@ -16,7 +16,7 @@
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
📏 Real-time line count: **3,806 lines** (run `bash core_agent_lines.sh` to verify anytime) 📏 Real-time line count: **3,862 lines** (run `bash core_agent_lines.sh` to verify anytime)
## 📢 News ## 📢 News
@@ -776,6 +776,21 @@ Two transport modes are supported:
| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` | | **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
| **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) | | **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) |
Use `toolTimeout` to override the default 30s per-call timeout for slow servers:
```json
{
"tools": {
"mcpServers": {
"my-slow-server": {
"url": "https://example.com/mcp/",
"toolTimeout": 120
}
}
}
}
```
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed. MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
@@ -865,6 +880,59 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
``` ```
## 🐧 Linux Service
Run the gateway as a systemd user service so it starts automatically and restarts on failure.
**1. Find the nanobot binary path:**
```bash
which nanobot # e.g. /home/user/.local/bin/nanobot
```
**2. Create the service file** at `~/.config/systemd/user/nanobot-gateway.service` (replace `ExecStart` path if needed):
```ini
[Unit]
Description=Nanobot Gateway
After=network.target
[Service]
Type=simple
ExecStart=%h/.local/bin/nanobot gateway
Restart=always
RestartSec=10
NoNewPrivileges=yes
ProtectSystem=strict
ReadWritePaths=%h
[Install]
WantedBy=default.target
```
**3. Enable and start:**
```bash
systemctl --user daemon-reload
systemctl --user enable --now nanobot-gateway
```
**Common operations:**
```bash
systemctl --user status nanobot-gateway # check status
systemctl --user restart nanobot-gateway # restart after config changes
journalctl --user -u nanobot-gateway -f # follow logs
```
If you edit the `.service` file itself, run `systemctl --user daemon-reload` before restarting.
> **Note:** User services only run while you are logged in. To keep the gateway running after logout, enable lingering:
>
> ```bash
> loginctl enable-linger $USER
> ```
## 📁 Project Structure ## 📁 Project Structure
``` ```

View File

@@ -7,7 +7,7 @@ import json
import re import re
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Awaitable, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
@@ -95,6 +95,8 @@ class AgentLoop:
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
@@ -194,7 +196,8 @@ class AgentLoop:
clean = self._strip_think(response.content) clean = self._strip_think(response.content)
if clean: if clean:
await on_progress(clean) await on_progress(clean)
await on_progress(self._tool_hint(response.tool_calls)) else:
await on_progress(self._tool_hint(response.tool_calls))
tool_call_dicts = [ tool_call_dicts = [
{ {
@@ -270,6 +273,18 @@ class AgentLoop:
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
lock = self._consolidation_locks.get(session_key)
if lock is None:
lock = asyncio.Lock()
self._consolidation_locks[session_key] = lock
return lock
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
"""Drop lock entry if no longer in use."""
if not lock.locked():
self._consolidation_locks.pop(session_key, None)
async def _process_message( async def _process_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
@@ -305,33 +320,55 @@ class AgentLoop:
# Slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
messages_to_archive = session.messages.copy() lock = self._get_consolidation_lock(session.key)
self._consolidating.add(session.key)
try:
async with lock:
snapshot = session.messages[session.last_consolidated:]
if snapshot:
temp = Session(key=session.key)
temp.messages = list(snapshot)
if not await self._consolidate_memory(temp, archive_all=True):
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
except Exception:
logger.exception("/new archival failed for {}", session.key)
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
finally:
self._consolidating.discard(session.key)
self._prune_consolidation_lock(session.key, lock)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
self.sessions.invalidate(session.key) self.sessions.invalidate(session.key)
async def _consolidate_and_cleanup():
temp = Session(key=session.key)
temp.messages = messages_to_archive
await self._consolidate_memory(temp, archive_all=True)
asyncio.create_task(_consolidate_and_cleanup())
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.") content="New session started.")
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands") content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
if len(session.messages) > self.memory_window and session.key not in self._consolidating: if len(session.messages) > self.memory_window and session.key not in self._consolidating:
self._consolidating.add(session.key) self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key)
async def _consolidate_and_unlock(): async def _consolidate_and_unlock():
try: try:
await self._consolidate_memory(session) async with lock:
await self._consolidate_memory(session)
finally: finally:
self._consolidating.discard(session.key) self._consolidating.discard(session.key)
self._prune_consolidation_lock(session.key, lock)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)
asyncio.create_task(_consolidate_and_unlock()) _task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(_task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
@@ -376,9 +413,9 @@ class AgentLoop:
metadata=msg.metadata or {}, metadata=msg.metadata or {},
) )
async def _consolidate_memory(self, session, archive_all: bool = False) -> None: async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Delegate to MemoryStore.consolidate().""" """Delegate to MemoryStore.consolidate(). Returns True on success."""
await MemoryStore(self.workspace).consolidate( return await MemoryStore(self.workspace).consolidate(
session, self.provider, self.model, session, self.provider, self.model,
archive_all=archive_all, memory_window=self.memory_window, archive_all=archive_all, memory_window=self.memory_window,
) )

View File

@@ -74,8 +74,11 @@ class MemoryStore:
*, *,
archive_all: bool = False, archive_all: bool = False,
memory_window: int = 50, memory_window: int = 50,
) -> None: ) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.""" """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
Returns True on success (including no-op), False on failure.
"""
if archive_all: if archive_all:
old_messages = session.messages old_messages = session.messages
keep_count = 0 keep_count = 0
@@ -83,12 +86,12 @@ class MemoryStore:
else: else:
keep_count = memory_window // 2 keep_count = memory_window // 2
if len(session.messages) <= keep_count: if len(session.messages) <= keep_count:
return return True
if len(session.messages) - session.last_consolidated <= 0: if len(session.messages) - session.last_consolidated <= 0:
return return True
old_messages = session.messages[session.last_consolidated:-keep_count] old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages: if not old_messages:
return return True
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count) logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
lines = [] lines = []
@@ -119,7 +122,7 @@ class MemoryStore:
if not response.has_tool_calls: if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping") logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return return False
args = response.tool_calls[0].arguments args = response.tool_calls[0].arguments
if entry := args.get("history_entry"): if entry := args.get("history_entry"):
@@ -134,5 +137,7 @@ class MemoryStore:
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
except Exception as e: return True
logger.error("Memory consolidation failed: {}", e) except Exception:
logger.exception("Memory consolidation failed")
return False

View File

@@ -13,8 +13,11 @@ def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path |
if not p.is_absolute() and workspace: if not p.is_absolute() and workspace:
p = workspace / p p = workspace / p
resolved = p.resolve() resolved = p.resolve()
if allowed_dir and not str(resolved).startswith(str(allowed_dir.resolve())): if allowed_dir:
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") try:
resolved.relative_to(allowed_dir.resolve())
except ValueError:
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
return resolved return resolved

View File

@@ -1,5 +1,6 @@
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools.""" """MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import Any from typing import Any
@@ -13,12 +14,13 @@ from nanobot.agent.tools.registry import ToolRegistry
class MCPToolWrapper(Tool): class MCPToolWrapper(Tool):
"""Wraps a single MCP server tool as a nanobot Tool.""" """Wraps a single MCP server tool as a nanobot Tool."""
def __init__(self, session, server_name: str, tool_def): def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
self._session = session self._session = session
self._original_name = tool_def.name self._original_name = tool_def.name
self._name = f"mcp_{server_name}_{tool_def.name}" self._name = f"mcp_{server_name}_{tool_def.name}"
self._description = tool_def.description or tool_def.name self._description = tool_def.description or tool_def.name
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}} self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
self._tool_timeout = tool_timeout
@property @property
def name(self) -> str: def name(self) -> str:
@@ -34,7 +36,14 @@ class MCPToolWrapper(Tool):
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
from mcp import types from mcp import types
result = await self._session.call_tool(self._original_name, arguments=kwargs) try:
result = await asyncio.wait_for(
self._session.call_tool(self._original_name, arguments=kwargs),
timeout=self._tool_timeout,
)
except asyncio.TimeoutError:
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
return f"(MCP tool call timed out after {self._tool_timeout}s)"
parts = [] parts = []
for block in result.content: for block in result.content:
if isinstance(block, types.TextContent): if isinstance(block, types.TextContent):
@@ -83,7 +92,7 @@ async def connect_mcp_servers(
tools = await session.list_tools() tools = await session.list_tools()
for tool_def in tools.tools: for tool_def in tools.tools:
wrapper = MCPToolWrapper(session, name, tool_def) wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
registry.register(wrapper) registry.register(wrapper)
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name) logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)

View File

@@ -304,7 +304,8 @@ class EmailChannel(BaseChannel):
self._processed_uids.add(uid) self._processed_uids.add(uid)
# mark_seen is the primary dedup; this set is a safety net # mark_seen is the primary dedup; this set is a safety net
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS: if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
self._processed_uids.clear() # Evict a random half to cap memory; mark_seen is the primary dedup
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
if mark_seen: if mark_seen:
client.store(imap_id, "+FLAGS", "\\Seen") client.store(imap_id, "+FLAGS", "\\Seen")

View File

@@ -55,7 +55,6 @@ class QQChannel(BaseChannel):
self.config: QQConfig = config self.config: QQConfig = config
self._client: "botpy.Client | None" = None self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000) self._processed_ids: deque = deque(maxlen=1000)
self._bot_task: asyncio.Task | None = None
async def start(self) -> None: async def start(self) -> None:
"""Start the QQ bot.""" """Start the QQ bot."""
@@ -71,8 +70,8 @@ class QQChannel(BaseChannel):
BotClass = _make_bot_class(self) BotClass = _make_bot_class(self)
self._client = BotClass() self._client = BotClass()
self._bot_task = asyncio.create_task(self._run_bot())
logger.info("QQ bot started (C2C private message)") logger.info("QQ bot started (C2C private message)")
await self._run_bot()
async def _run_bot(self) -> None: async def _run_bot(self) -> None:
"""Run the bot connection with auto-reconnect.""" """Run the bot connection with auto-reconnect."""
@@ -88,11 +87,10 @@ class QQChannel(BaseChannel):
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the QQ bot.""" """Stop the QQ bot."""
self._running = False self._running = False
if self._bot_task: if self._client:
self._bot_task.cancel()
try: try:
await self._bot_task await self._client.close()
except asyncio.CancelledError: except Exception:
pass pass
logger.info("QQ bot stopped") logger.info("QQ bot stopped")
@@ -130,5 +128,5 @@ class QQChannel(BaseChannel):
content=content, content=content,
metadata={"message_id": data.id}, metadata={"message_id": data.id},
) )
except Exception as e: except Exception:
logger.error("Error handling QQ message: {}", e) logger.exception("Error handling QQ message")

View File

@@ -179,18 +179,21 @@ class SlackChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.debug("Slack reactions_add failed: {}", e) logger.debug("Slack reactions_add failed: {}", e)
await self._handle_message( try:
sender_id=sender_id, await self._handle_message(
chat_id=chat_id, sender_id=sender_id,
content=text, chat_id=chat_id,
metadata={ content=text,
"slack": { metadata={
"event": event, "slack": {
"thread_ts": thread_ts, "event": event,
"channel_type": channel_type, "thread_ts": thread_ts,
} "channel_type": channel_type,
}, }
) },
)
except Exception:
logger.exception("Error handling Slack message from {}", sender_id)
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
if channel_type == "im": if channel_type == "im":

View File

@@ -260,6 +260,7 @@ class MCPServerConfig(Base):
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
url: str = "" # HTTP: streamable HTTP endpoint URL url: str = "" # HTTP: streamable HTTP endpoint URL
headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
tool_timeout: int = 30 # Seconds before a tool call is cancelled
class ToolsConfig(Base): class ToolsConfig(Base):

View File

@@ -40,7 +40,7 @@ class CustomProvider(LLMProvider):
return LLMResponse( return LLMResponse(
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop", content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
reasoning_content=getattr(msg, "reasoning_content", None), reasoning_content=getattr(msg, "reasoning_content", None) or None,
) )
def get_default_model(self) -> str: def get_default_model(self) -> str:

View File

@@ -257,7 +257,7 @@ class LiteLLMProvider(LLMProvider):
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
} }
reasoning_content = getattr(message, "reasoning_content", None) reasoning_content = getattr(message, "reasoning_content", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=message.content,

View File

@@ -1,6 +1,7 @@
"""Session management for conversation history.""" """Session management for conversation history."""
import json import json
import shutil
from pathlib import Path from pathlib import Path
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
@@ -108,9 +109,11 @@ class SessionManager:
if not path.exists(): if not path.exists():
legacy_path = self._get_legacy_session_path(key) legacy_path = self._get_legacy_session_path(key)
if legacy_path.exists(): if legacy_path.exists():
import shutil try:
shutil.move(str(legacy_path), str(path)) shutil.move(str(legacy_path), str(path))
logger.info("Migrated session {} from legacy path", key) logger.info("Migrated session {} from legacy path", key)
except Exception:
logger.exception("Failed to migrate session {}", key)
if not path.exists(): if not path.exists():
return None return None

View File

@@ -1,5 +1,8 @@
"""Test session management with cache-friendly message handling.""" """Test session management with cache-friendly message handling."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from pathlib import Path from pathlib import Path
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
@@ -475,3 +478,351 @@ class TestEmptyAndBoundarySessions:
expected_count = 60 - KEEP_COUNT - 10 expected_count = 60 - KEEP_COUNT - 10
assert len(old_messages) == expected_count assert len(old_messages) == expected_count
assert_messages_content(old_messages, 10, 34) assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
    """Test that consolidation tasks are deduplicated and serialized.

    Covers the AgentLoop consolidation lock/guard behavior:
    - only one background consolidation task per session key,
    - /new serializes with in-flight consolidation and archives before clear,
    - task references and per-key locks are cleaned up afterwards.
    """

    @staticmethod
    def _make_loop(tmp_path: Path, n_turns: int = 15):
        """Build an AgentLoop with a mocked provider and a pre-filled session.

        The session is saved under key ``cli:test`` with ``n_turns``
        user/assistant message pairs (2 * n_turns messages total), so with
        ``memory_window=10`` the default of 15 turns exceeds the window and
        triggers background consolidation.
        """
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=MessageBus(), provider=provider, workspace=tmp_path,
            model="test-model", memory_window=10,
        )
        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(n_turns):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)
        return loop

    @staticmethod
    def _msg(content: str):
        """Inbound message addressed to the ``cli:test`` session."""
        from nanobot.bus.events import InboundMessage

        return InboundMessage(channel="cli", sender_id="user", chat_id="test", content=content)

    @pytest.mark.asyncio
    async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
        """Concurrent messages above memory_window spawn only one consolidation task."""
        loop = self._make_loop(tmp_path)

        consolidation_calls = 0

        # Returns True to honor the bool contract of _consolidate_memory.
        async def _fake_consolidate(_session, archive_all: bool = False) -> bool:
            nonlocal consolidation_calls
            consolidation_calls += 1
            await asyncio.sleep(0.05)
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = self._msg("hello")
        await loop._process_message(msg)
        await loop._process_message(msg)
        await asyncio.sleep(0.1)

        assert consolidation_calls == 1, (
            f"Expected exactly 1 consolidation, got {consolidation_calls}"
        )

    @pytest.mark.asyncio
    async def test_new_command_guard_prevents_concurrent_consolidation(
        self, tmp_path: Path
    ) -> None:
        """/new command does not run consolidation concurrently with in-flight consolidation."""
        loop = self._make_loop(tmp_path)

        consolidation_calls = 0
        active = 0
        max_active = 0

        # Must return True: the /new path treats a falsy return as archival
        # failure and would skip session.clear(), masking what we measure here.
        async def _fake_consolidate(_session, archive_all: bool = False) -> bool:
            nonlocal consolidation_calls, active, max_active
            consolidation_calls += 1
            active += 1
            max_active = max(max_active, active)
            await asyncio.sleep(0.05)
            active -= 1
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await loop._process_message(self._msg("hello"))
        await loop._process_message(self._msg("/new"))
        await asyncio.sleep(0.1)

        assert consolidation_calls == 2, (
            f"Expected normal + /new consolidations, got {consolidation_calls}"
        )
        assert max_active == 1, (
            f"Expected serialized consolidation, observed concurrency={max_active}"
        )

    @pytest.mark.asyncio
    async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
        """create_task results are tracked in _consolidation_tasks while in flight."""
        loop = self._make_loop(tmp_path)

        started = asyncio.Event()

        async def _slow_consolidate(_session, archive_all: bool = False) -> bool:
            started.set()
            await asyncio.sleep(0.1)
            return True

        loop._consolidate_memory = _slow_consolidate  # type: ignore[method-assign]

        await loop._process_message(self._msg("hello"))
        await started.wait()
        assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
        await asyncio.sleep(0.15)
        assert len(loop._consolidation_tasks) == 0, (
            "Task reference must be removed after completion"
        )

    @pytest.mark.asyncio
    async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
        self, tmp_path: Path
    ) -> None:
        """/new waits for in-flight consolidation and archives before clear."""
        loop = self._make_loop(tmp_path)

        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = 0

        async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return True
            started.set()
            await release.wait()
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await loop._process_message(self._msg("hello"))
        await started.wait()

        pending_new = asyncio.create_task(loop._process_message(self._msg("/new")))
        await asyncio.sleep(0.02)
        assert not pending_new.done(), "/new should wait while consolidation is in-flight"
        release.set()
        response = await pending_new

        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
        session_after = loop.sessions.get_or_create("cli:test")
        assert session_after.messages == [], "Session should be cleared after successful archival"

    @pytest.mark.asyncio
    async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
        """/new must keep session data if archive step reports failure."""
        loop = self._make_loop(tmp_path, n_turns=5)
        before_count = len(loop.sessions.get_or_create("cli:test").messages)

        # Fail only the archive_all (/new) path; normal consolidation succeeds.
        async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
            return not archive_all

        loop._consolidate_memory = _failing_consolidate  # type: ignore[method-assign]

        response = await loop._process_message(self._msg("/new"))

        assert response is not None
        assert "failed" in response.content.lower()
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == before_count, (
            "Session must remain intact when /new archival fails"
        )

    @pytest.mark.asyncio
    async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
        self, tmp_path: Path
    ) -> None:
        """/new should archive only messages not yet consolidated by prior task."""
        loop = self._make_loop(tmp_path)

        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = -1

        async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return True
            started.set()
            await release.wait()
            # Simulate the window consolidation marking all but 3 messages done.
            sess.last_consolidated = len(sess.messages) - 3
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await loop._process_message(self._msg("hello"))
        await started.wait()

        pending_new = asyncio.create_task(loop._process_message(self._msg("/new")))
        await asyncio.sleep(0.02)
        assert not pending_new.done()
        release.set()
        response = await pending_new

        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count == 3, (
            f"Expected only unconsolidated tail to archive, got {archived_count}"
        )

    @pytest.mark.asyncio
    async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
        self, tmp_path: Path
    ) -> None:
        """/new should remove lock entry for fully invalidated session key."""
        loop = self._make_loop(tmp_path, n_turns=3)
        session_key = "cli:test"

        # Ensure lock exists before /new.
        _ = loop._get_consolidation_lock(session_key)
        assert session_key in loop._consolidation_locks

        async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
            return True

        loop._consolidate_memory = _ok_consolidate  # type: ignore[method-assign]

        response = await loop._process_message(self._msg("/new"))
        assert response is not None
        assert "new session started" in response.content.lower()
        assert session_key not in loop._consolidation_locks