Merge remote-tracking branch 'origin/main' into pr-420
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""Test session management with cache-friendly message handling."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
@@ -475,3 +478,351 @@ class TestEmptyAndBoundarySessions:
|
||||
expected_count = 60 - KEEP_COUNT - 10
|
||||
assert len(old_messages) == expected_count
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestConsolidationDeduplicationGuard:
|
||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls
|
||||
consolidation_calls += 1
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await loop._process_message(msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 1, (
|
||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls, active, max_active
|
||||
consolidation_calls += 1
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
await asyncio.sleep(0.05)
|
||||
active -= 1
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
await loop._process_message(new_msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 2, (
|
||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||
)
|
||||
assert max_active == 1, (
|
||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||
started.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
await started.wait()
|
||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
assert len(loop._consolidation_tasks) == 0, (
|
||||
"Task reference must be removed after completion"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new waits for in-flight consolidation and archives before clear."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = 0
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
started.set()
|
||||
await release.wait()
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||
"""/new must keep session data if archive step reports failure."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
if archive_all:
|
||||
return False
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "failed" in response.content.lower()
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert len(session_after.messages) == before_count, (
|
||||
"Session must remain intact when /new archival fails"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new should archive only messages not yet consolidated by prior task."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
|
||||
started.set()
|
||||
await release.wait()
|
||||
sess.last_consolidated = len(sess.messages) - 3
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done()
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count == 3, (
|
||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new should remove lock entry for fully invalidated session key."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
# Ensure lock exists before /new.
|
||||
_ = loop._get_consolidation_lock(session.key)
|
||||
assert session.key in loop._consolidation_locks
|
||||
|
||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert session.key not in loop._consolidation_locks
|
||||
|
||||
66
tests/test_context_prompt_cache.py
Normal file
66
tests/test_context_prompt_cache.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for cache-friendly prompt construction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime as real_datetime
|
||||
from pathlib import Path
|
||||
import datetime as datetime_module
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
class _FakeDatetime(real_datetime):
|
||||
current = real_datetime(2026, 2, 24, 13, 59)
|
||||
|
||||
@classmethod
|
||||
def now(cls, tz=None): # type: ignore[override]
|
||||
return cls.current
|
||||
|
||||
|
||||
def _make_workspace(tmp_path: Path) -> Path:
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir(parents=True)
|
||||
return workspace
|
||||
|
||||
|
||||
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||
"""System prompt should not change just because wall clock minute changes."""
|
||||
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
|
||||
prompt1 = builder.build_system_prompt()
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
|
||||
prompt2 = builder.build_system_prompt()
|
||||
|
||||
assert prompt1 == prompt2
|
||||
|
||||
|
||||
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
"""Runtime metadata should be a separate user message before the actual user message."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
messages = builder.build_messages(
|
||||
history=[],
|
||||
current_message="Return exactly: OK",
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
)
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "## Current Session" not in messages[0]["content"]
|
||||
|
||||
assert messages[-2]["role"] == "user"
|
||||
runtime_content = messages[-2]["content"]
|
||||
assert isinstance(runtime_content, str)
|
||||
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
|
||||
assert "Current Time:" in runtime_content
|
||||
assert "Channel: cli" in runtime_content
|
||||
assert "Chat ID: direct" in runtime_content
|
||||
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == "Return exactly: OK"
|
||||
29
tests/test_cron_commands.py
Normal file
29
tests/test_cron_commands.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
|
||||
monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"cron",
|
||||
"add",
|
||||
"--name",
|
||||
"demo",
|
||||
"--message",
|
||||
"hello",
|
||||
"--cron",
|
||||
"0 9 * * *",
|
||||
"--tz",
|
||||
"America/Vancovuer",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
|
||||
assert not (tmp_path / "cron" / "jobs.json").exists()
|
||||
30
tests/test_cron_service.py
Normal file
30
tests/test_cron_service.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
|
||||
service.add_job(
|
||||
name="tz typo",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert service.list_jobs(include_disabled=True) == []
|
||||
|
||||
|
||||
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
job = service.add_job(
|
||||
name="tz ok",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert job.schedule.tz == "America/Vancouver"
|
||||
assert job.state.next_run_at_ms is not None
|
||||
@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# Mark alice as someone who sent us an email (making this a "reply")
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||
|
||||
# Reply should be skipped (auto_reply_enabled=False)
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
)
|
||||
assert fake_instances == []
|
||||
|
||||
# Reply with force_send=True should be sent
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# bob@example.com has never sent us an email (proactive send)
|
||||
# This should be sent even with auto_reply_enabled=False
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="bob@example.com",
|
||||
content="Hello, this is a proactive email.",
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
sent = fake_instances[0].sent_messages[0]
|
||||
assert sent["To"] == "bob@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
|
||||
44
tests/test_heartbeat_service.py
Normal file
44
tests/test_heartbeat_service.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.heartbeat.service import (
|
||||
HEARTBEAT_OK_TOKEN,
|
||||
HeartbeatService,
|
||||
)
|
||||
|
||||
|
||||
def test_heartbeat_ok_detection() -> None:
|
||||
def is_ok(response: str) -> bool:
|
||||
return HEARTBEAT_OK_TOKEN in response.upper()
|
||||
|
||||
assert is_ok("HEARTBEAT_OK")
|
||||
assert is_ok("`HEARTBEAT_OK`")
|
||||
assert is_ok("**HEARTBEAT_OK**")
|
||||
assert is_ok("heartbeat_ok")
|
||||
assert is_ok("HEARTBEAT_OK.")
|
||||
|
||||
assert not is_ok("HEARTBEAT_NOT_OK")
|
||||
assert not is_ok("all good")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_is_idempotent(tmp_path) -> None:
|
||||
async def _on_heartbeat(_: str) -> str:
|
||||
return "HEARTBEAT_OK"
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
on_heartbeat=_on_heartbeat,
|
||||
interval_s=9999,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
await service.start()
|
||||
first_task = service._task
|
||||
await service.start()
|
||||
|
||||
assert service._task is first_task
|
||||
|
||||
service.stop()
|
||||
await asyncio.sleep(0)
|
||||
147
tests/test_memory_consolidation_types.py
Normal file
147
tests/test_memory_consolidation_types.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
|
||||
|
||||
Regression test for https://github.com/HKUDS/nanobot/issues/1042
|
||||
When memory consolidation receives dict values instead of strings from the LLM
|
||||
tool call response, it should serialize them to JSON instead of raising TypeError.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||
"""Create a mock session with messages."""
|
||||
session = MagicMock()
|
||||
session.messages = [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
session.last_consolidated = 0
|
||||
return session
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
"""Create an LLMResponse with a save_memory tool call."""
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={
|
||||
"history_entry": history_entry,
|
||||
"memory_update": memory_update,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestMemoryConsolidationTypeHandling:
|
||||
"""Test that consolidation handles various argument types correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_work(self, tmp_path: Path) -> None:
|
||||
"""Normal case: LLM returns string arguments."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
|
||||
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
|
||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
history_content = store.history_file.read_text()
|
||||
parsed = json.loads(history_content.strip())
|
||||
assert parsed["summary"] == "User discussed testing."
|
||||
|
||||
memory_content = store.memory_file.read_text()
|
||||
parsed_mem = json.loads(memory_content)
|
||||
assert "User likes testing" in parsed_mem["facts"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a JSON string instead of parsed dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a JSON string (not yet parsed)
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=json.dumps({
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}),
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
|
||||
"""When LLM doesn't use the save_memory tool, return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when messages < keep_count."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
session = _make_session(message_count=10)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
167
tests/test_task_cancel.py
Normal file
167
tests/test_task_cancel.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for /stop task cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_loop():
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||
return loop, bus
|
||||
|
||||
|
||||
class TestHandleStop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_no_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "No active task" in out.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow_task():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow_task())
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = [task]
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
|
||||
assert cancelled.is_set()
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "stopped" in out.content.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_multiple_tasks(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
events = [asyncio.Event(), asyncio.Event()]
|
||||
|
||||
async def slow(idx):
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
events[idx].set()
|
||||
raise
|
||||
|
||||
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = tasks
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
|
||||
assert all(e.is_set() for e in events)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "2 task" in out.content
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_processes_and_publishes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
|
||||
loop._process_message = AsyncMock(
|
||||
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||
)
|
||||
await loop._dispatch(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert out.content == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_lock_serializes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
order = []
|
||||
|
||||
async def mock_process(m, **kwargs):
|
||||
order.append(f"start-{m.content}")
|
||||
await asyncio.sleep(0.05)
|
||||
order.append(f"end-{m.content}")
|
||||
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||
|
||||
loop._process_message = mock_process
|
||||
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||
|
||||
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||
await asyncio.gather(t1, t2)
|
||||
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||
|
||||
|
||||
class TestSubagentCancellation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow())
|
||||
await asyncio.sleep(0)
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_no_tasks(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
Reference in New Issue
Block a user