Merge branch 'main' into pr-1810
# Conflicts: # nanobot/agent/memory.py # tests/test_memory_consolidation_types.py
This commit is contained in:
@@ -114,6 +114,64 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
|
||||
assert config.get_provider_name() == "openai_codex"
|
||||
|
||||
|
||||
def test_config_matches_explicit_ollama_prefix_without_api_key():
    """A default config whose model is ``ollama/...`` resolves to the Ollama provider.

    Only the model name is set; the test then checks the resolved provider
    name and the API base reported by the config.
    """
    cfg = Config()
    cfg.agents.defaults.model = "ollama/llama3.2"

    resolved_provider = cfg.get_provider_name()
    assert resolved_provider == "ollama"

    resolved_base = cfg.get_api_base()
    assert resolved_base == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
    """Selecting provider ``ollama`` explicitly yields the localhost API base.

    The model name carries no provider prefix here; the provider is chosen
    solely through ``agents.defaults.provider``.
    """
    cfg = Config()
    agent_defaults = cfg.agents.defaults
    agent_defaults.provider = "ollama"
    agent_defaults.model = "llama3.2"

    resolved_provider = cfg.get_provider_name()
    assert resolved_provider == "ollama"

    resolved_base = cfg.get_api_base()
    assert resolved_base == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_auto_detects_ollama_from_local_api_base():
    """With provider ``auto``, a configured Ollama ``apiBase`` selects Ollama."""
    raw_settings = {
        "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
        "providers": {"ollama": {"apiBase": "http://localhost:11434"}},
    }
    cfg = Config.model_validate(raw_settings)

    resolved_provider = cfg.get_provider_name()
    assert resolved_provider == "ollama"

    resolved_base = cfg.get_api_base()
    assert resolved_base == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
    """With provider ``auto`` and both vLLM and Ollama configured, Ollama is chosen.

    vLLM is deliberately listed first in the ``providers`` mapping, so the
    expected result does not follow from declaration order.
    """
    local_providers = {
        "vllm": {"apiBase": "http://localhost:8000"},
        "ollama": {"apiBase": "http://localhost:11434"},
    }
    raw_settings = {
        "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
        "providers": local_providers,
    }
    cfg = Config.model_validate(raw_settings)

    resolved_provider = cfg.get_provider_name()
    assert resolved_provider == "ollama"

    resolved_base = cfg.get_api_base()
    assert resolved_base == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_falls_back_to_vllm_when_ollama_not_configured():
    """With provider ``auto`` and only vLLM configured, vLLM is selected."""
    local_providers = {
        "vllm": {"apiBase": "http://localhost:8000"},
    }
    raw_settings = {
        "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
        "providers": local_providers,
    }
    cfg = Config.model_validate(raw_settings)

    resolved_provider = cfg.get_provider_name()
    assert resolved_provider == "vllm"

    resolved_base = cfg.get_api_base()
    assert resolved_base == "http://localhost:8000"
|
||||
|
||||
|
||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||
|
||||
@@ -267,6 +325,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
    """``agent -m`` output names both ``memoryWindow`` and ``contextWindowTokens``.

    The mocked runtime config gets ``memory_window`` set before the command
    runs; the command must still exit successfully.
    """
    agent_defaults = mock_agent_runtime["config"].agents.defaults
    agent_defaults.memory_window = 100

    outcome = runner.invoke(app, ["agent", "-m", "hello"])

    assert outcome.exit_code == 0
    printed = outcome.stdout
    assert "memoryWindow" in printed
    assert "contextWindowTokens" in printed
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
@@ -328,6 +396,28 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
assert config.workspace_path == override
|
||||
|
||||
|
||||
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
    """``gateway`` output names ``memoryWindow`` and ``contextWindowTokens``.

    Config loading and workspace syncing are stubbed out, and the provider
    factory is replaced with one that raises ``_StopGateway`` so the command
    stops right after its start-up output.
    """
    instance_dir = tmp_path / "instance"
    instance_dir.mkdir(parents=True)
    config_file = instance_dir / "config.json"
    config_file.write_text("{}")

    config = Config()
    config.agents.defaults.memory_window = 100

    def _abort_gateway(_config):
        # Raising here halts the gateway command before any real provider is built.
        raise _StopGateway("stop")

    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
    monkeypatch.setattr("nanobot.cli.commands._make_provider", _abort_gateway)

    outcome = runner.invoke(app, ["gateway", "--config", str(config_file)])

    assert isinstance(outcome.exception, _StopGateway)
    printed = outcome.stdout
    assert "memoryWindow" in printed
    assert "contextWindowTokens" in printed
|
||||
|
||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
@@ -356,3 +446,47 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
    """Without ``--port``, the gateway output reports the port from the config.

    The provider factory is replaced with one that raises ``_StopGateway`` so
    the command ends once its start-up output has been produced.
    """
    instance_dir = tmp_path / "instance"
    instance_dir.mkdir(parents=True)
    config_file = instance_dir / "config.json"
    config_file.write_text("{}")

    config = Config()
    config.gateway.port = 18791

    def _abort_gateway(_config):
        # Raising here halts the gateway command before any real provider is built.
        raise _StopGateway("stop")

    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
    monkeypatch.setattr("nanobot.cli.commands._make_provider", _abort_gateway)

    cli_args = ["gateway", "--config", str(config_file)]
    outcome = runner.invoke(app, cli_args)

    assert isinstance(outcome.exception, _StopGateway)
    assert "port 18791" in outcome.stdout
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
    """An explicit ``--port`` value is reported instead of the configured port.

    The config carries port 18791 while the CLI passes 18792; the output must
    mention the latter.
    """
    instance_dir = tmp_path / "instance"
    instance_dir.mkdir(parents=True)
    config_file = instance_dir / "config.json"
    config_file.write_text("{}")

    config = Config()
    config.gateway.port = 18791

    def _abort_gateway(_config):
        # Raising here halts the gateway command before any real provider is built.
        raise _StopGateway("stop")

    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
    monkeypatch.setattr("nanobot.cli.commands._make_provider", _abort_gateway)

    cli_args = ["gateway", "--config", str(config_file), "--port", "18792"]
    outcome = runner.invoke(app, cli_args)

    assert isinstance(outcome.exception, _StopGateway)
    assert "port 18792" in outcome.stdout
|
||||
|
||||
88
tests/test_config_migration.py
Normal file
88
tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
    """Loading a config with legacy ``memoryWindow`` keeps ``maxTokens`` and flags a warning.

    The file on disk sets only ``maxTokens`` and ``memoryWindow``; the loaded
    config is expected to report 65_536 for ``context_window_tokens``.
    """
    legacy_settings = {
        "agents": {
            "defaults": {
                "maxTokens": 1234,
                "memoryWindow": 42,
            }
        }
    }
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(legacy_settings), encoding="utf-8")

    loaded_defaults = load_config(config_path).agents.defaults

    assert loaded_defaults.max_tokens == 1234
    assert loaded_defaults.context_window_tokens == 65_536
    assert loaded_defaults.should_warn_deprecated_memory_window is True
|
||||
|
||||
|
||||
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
    """A load/save round trip writes ``contextWindowTokens`` and drops ``memoryWindow``.

    The legacy file is loaded and saved back to the same path, then re-read as
    raw JSON to inspect the keys actually written.
    """
    legacy_settings = {
        "agents": {
            "defaults": {
                "maxTokens": 2222,
                "memoryWindow": 30,
            }
        }
    }
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(legacy_settings), encoding="utf-8")

    loaded = load_config(config_path)
    save_config(loaded, config_path)

    written = json.loads(config_path.read_text(encoding="utf-8"))
    written_defaults = written["agents"]["defaults"]

    assert written_defaults["maxTokens"] == 2222
    assert written_defaults["contextWindowTokens"] == 65_536
    assert "memoryWindow" not in written_defaults
|
||||
|
||||
|
||||
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
    """Running ``onboard`` with input ``n`` rewrites a legacy config file.

    Config and workspace paths are redirected into ``tmp_path``. Afterwards
    the saved file must keep ``maxTokens``, contain ``contextWindowTokens``,
    and no longer contain ``memoryWindow``.
    """
    legacy_settings = {
        "agents": {
            "defaults": {
                "maxTokens": 3333,
                "memoryWindow": 50,
            }
        }
    }
    config_path = tmp_path / "config.json"
    workspace = tmp_path / "workspace"
    config_path.write_text(json.dumps(legacy_settings), encoding="utf-8")

    monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
    monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)

    outcome = runner.invoke(app, ["onboard"], input="n\n")

    assert outcome.exit_code == 0
    assert "contextWindowTokens" in outcome.stdout

    written = json.loads(config_path.read_text(encoding="utf-8"))
    written_defaults = written["agents"]["defaults"]
    assert written_defaults["maxTokens"] == 3333
    assert written_defaults["contextWindowTokens"] == 65_536
    assert "memoryWindow" not in written_defaults
|
||||
@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestConsolidationDeduplicationGuard:
|
||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||
class TestNewCommandArchival:
|
||||
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||
@staticmethod
|
||||
def _make_loop(tmp_path: Path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=1,
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls
|
||||
consolidation_calls += 1
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await loop._process_message(msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 1, (
|
||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls, active, max_active
|
||||
consolidation_calls += 1
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
await asyncio.sleep(0.05)
|
||||
active -= 1
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
await loop._process_message(new_msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 2, (
|
||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||
)
|
||||
assert max_active == 1, (
|
||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||
started.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
await started.wait()
|
||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
assert len(loop._consolidation_tasks) == 0, (
|
||||
"Task reference must be removed after completion"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new waits for in-flight consolidation and archives before clear."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = 0
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
started.set()
|
||||
await release.wait()
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||
return loop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||
"""/new must keep session data if archive step reports failure."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
if archive_all:
|
||||
return False
|
||||
return True
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
return False
|
||||
|
||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "failed" in response.content.lower()
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert len(session_after.messages) == before_count, (
|
||||
"Session must remain intact when /new archival fails"
|
||||
)
|
||||
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new should archive only messages not yet consolidated by prior task."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
session.last_consolidated = len(session.messages) - 3
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
|
||||
started.set()
|
||||
await release.wait()
|
||||
sess.last_consolidated = len(sess.messages) - 3
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done()
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count == 3, (
|
||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||
)
|
||||
assert archived_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||
"""/new clears session and returns confirmation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.dingtalk import DingTalkChannel
|
||||
import nanobot.channels.dingtalk as dingtalk_module
|
||||
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
|
||||
@@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None:
|
||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
assert call["json"]["openConversationId"] == "conv123"
|
||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
    """The handler publishes the recognition transcript when the message has no text.

    ``ChatbotMessage`` and ``AckMessage`` in the DingTalk module are replaced
    with local stand-ins, so the test controls the parsed message fields and
    the acknowledgement status value.
    """
    bus = MessageBus()
    dingtalk_config = DingTalkConfig(
        client_id="app", client_secret="secret", allow_from=["user1"]
    )
    channel = DingTalkChannel(dingtalk_config, bus)
    handler = NanobotDingTalkHandler(channel)

    class _FakeChatbotMessage:
        # No text at all; the transcript is only available via extensions.
        text = None
        extensions = {"content": {"recognition": "voice transcript"}}
        sender_staff_id = "user1"
        sender_id = "fallback-user"
        sender_nick = "Alice"
        message_type = "audio"

        @staticmethod
        def from_dict(_data):
            return _FakeChatbotMessage()

    monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
    monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))

    callback_payload = {
        "conversationType": "2",
        "conversationId": "conv123",
        "text": {"content": ""},
    }
    status, body = await handler.process(SimpleNamespace(data=callback_payload))

    # Snapshot the task collection before awaiting so every task has finished
    # before the bus is read.
    await asyncio.gather(*list(channel._background_tasks))
    inbound = await bus.consume_inbound()

    assert (status, body) == ("OK", "OK")
    assert inbound.content == "voice transcript"
    assert inbound.sender_id == "user1"
    assert inbound.chat_id == "group:conv123"
|
||||
|
||||
251
tests/test_filesystem_tools.py
Normal file
251
tests/test_filesystem_tools.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import (
|
||||
EditFileTool,
|
||||
ListDirTool,
|
||||
ReadFileTool,
|
||||
_find_match,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFileTool:
    """Behavioural tests for ``ReadFileTool.execute``."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Provide a ``ReadFileTool`` whose workspace is the test's temp directory."""
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def sample_file(self, tmp_path):
        """Create a 20-line file containing ``line 1`` through ``line 20``."""
        numbered_lines = [f"line {number}" for number in range(1, 21)]
        path = tmp_path / "sample.txt"
        path.write_text("\n".join(numbered_lines), encoding="utf-8")
        return path

    @pytest.mark.asyncio
    async def test_basic_read_has_line_numbers(self, tool, sample_file):
        """A plain read prefixes the first and last lines with their numbers."""
        output = await tool.execute(path=str(sample_file))
        assert "1| line 1" in output
        assert "20| line 20" in output

    @pytest.mark.asyncio
    async def test_offset_and_limit(self, tool, sample_file):
        """``offset``/``limit`` select a window and the output hints at the next offset."""
        output = await tool.execute(path=str(sample_file), offset=5, limit=3)
        assert "5| line 5" in output
        assert "7| line 7" in output
        assert "8| line 8" not in output
        assert "Use offset=8 to continue" in output

    @pytest.mark.asyncio
    async def test_offset_beyond_end(self, tool, sample_file):
        """An offset past the last line produces an error message."""
        output = await tool.execute(path=str(sample_file), offset=999)
        assert "Error" in output
        assert "beyond end" in output

    @pytest.mark.asyncio
    async def test_end_of_file_marker(self, tool, sample_file):
        """Reading through the final line includes an end-of-file marker."""
        output = await tool.execute(path=str(sample_file), offset=1, limit=9999)
        assert "End of file" in output

    @pytest.mark.asyncio
    async def test_empty_file(self, tool, tmp_path):
        """A zero-length file is reported as empty."""
        path = tmp_path / "empty.txt"
        path.write_text("", encoding="utf-8")
        output = await tool.execute(path=str(path))
        assert "Empty file" in output

    @pytest.mark.asyncio
    async def test_file_not_found(self, tool, tmp_path):
        """A path that does not exist produces a "not found" error."""
        missing = tmp_path / "nope.txt"
        output = await tool.execute(path=str(missing))
        assert "Error" in output
        assert "not found" in output

    @pytest.mark.asyncio
    async def test_char_budget_trims(self, tool, tmp_path):
        """When the selected slice exceeds _MAX_CHARS the output is trimmed."""
        path = tmp_path / "big.txt"
        # Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
        wide_lines = ["x" * 110] * 2000
        path.write_text("\n".join(wide_lines), encoding="utf-8")
        output = await tool.execute(path=str(path))
        # small margin for footer
        assert len(output) <= ReadFileTool._MAX_CHARS + 500
        assert "Use offset=" in output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_match (unit tests for the helper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindMatch:
    """Unit tests for the ``_find_match`` helper, which returns ``(match, count)``."""

    def test_exact_match(self):
        """A substring present once is returned verbatim with a count of one."""
        found, occurrences = _find_match("hello world", "world")
        assert found == "world"
        assert occurrences == 1

    def test_exact_no_match(self):
        """An absent substring yields ``None`` and a count of zero."""
        found, occurrences = _find_match("hello world", "xyz")
        assert found is None
        assert occurrences == 0

    def test_crlf_normalisation(self):
        """Pre-normalised multi-line text still matches exactly."""
        # Caller normalises CRLF before calling _find_match, so test with
        # pre-normalised content to verify exact match still works.
        content = "line1\nline2\nline3"
        old_text = "line1\nline2\nline3"
        found, occurrences = _find_match(content, old_text)
        assert found is not None
        assert occurrences == 1

    def test_line_trim_fallback(self):
        """Text differing only in indentation still matches."""
        content = "    def foo():\n        pass\n"
        old_text = "def foo():\n    pass"
        found, occurrences = _find_match(content, old_text)
        assert found is not None
        assert occurrences == 1
        # The returned match should be the *original* indented text
        assert "    def foo():" in found

    def test_line_trim_multiple_candidates(self):
        """Two indentation-insensitive candidates are both counted."""
        content = "  a\n  b\n  a\n  b\n"
        old_text = "a\nb"
        _, occurrences = _find_match(content, old_text)
        assert occurrences == 2

    def test_empty_old_text(self):
        """An empty search string matches as the empty string."""
        found, _ = _find_match("hello", "")
        # Empty string is always "in" any string via exact match
        assert found == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EditFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditFileTool:
    """End-to-end tests for ``EditFileTool.execute`` against real temp files.

    Every test writes a small file under pytest's ``tmp_path``, runs the tool
    and checks both the returned status string and the resulting file content.
    Files are written *and* read back as UTF-8 so the tests do not depend on
    the platform's default text encoding.
    """

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an ``EditFileTool`` whose workspace is the temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_exact_match(self, tool, tmp_path):
        """A unique exact substring is replaced in place."""
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="world", new_text="earth")
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "hello earth"

    @pytest.mark.asyncio
    async def test_crlf_normalisation(self, tool, tmp_path):
        """LF-only search text matches a CRLF file and CRLF is kept on write."""
        f = tmp_path / "crlf.py"
        f.write_bytes(b"line1\r\nline2\r\nline3")
        result = await tool.execute(
            path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
        )
        assert "Successfully" in result
        raw = f.read_bytes()
        assert b"LINE1" in raw
        # CRLF line endings should be preserved throughout the file
        assert b"\r\n" in raw
        # "Throughout" means no bare LF may be left once CRLF pairs are removed.
        assert b"\n" not in raw.replace(b"\r\n", b"")

    @pytest.mark.asyncio
    async def test_trim_fallback(self, tool, tmp_path):
        """Search text that differs only in line whitespace still edits the file."""
        f = tmp_path / "indent.py"
        f.write_text(" def foo():\n pass\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
        )
        assert "Successfully" in result
        assert "bar" in f.read_text(encoding="utf-8")

    @pytest.mark.asyncio
    async def test_ambiguous_match(self, tool, tmp_path):
        """Duplicate occurrences without ``replace_all`` are reported to the caller."""
        f = tmp_path / "dup.py"
        f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
        assert "appears" in result.lower() or "Warning" in result

    @pytest.mark.asyncio
    async def test_replace_all(self, tool, tmp_path):
        """``replace_all=True`` rewrites every occurrence."""
        f = tmp_path / "multi.py"
        f.write_text("foo bar foo bar foo", encoding="utf-8")
        result = await tool.execute(
            path=str(f), old_text="foo", new_text="baz", replace_all=True,
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "baz bar baz bar baz"

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        """Search text that is absent produces an error result."""
        f = tmp_path / "nf.py"
        f.write_text("hello", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
        assert "Error" in result
        assert "not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ListDirTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListDirTool:
    """Tests for ``ListDirTool.execute`` over directories built in ``tmp_path``."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return a ``ListDirTool`` whose workspace is the temp directory."""
        return ListDirTool(workspace=tmp_path)

    @pytest.fixture()
    def populated_dir(self, tmp_path):
        """Build a small tree with source files plus directories expected to be ignored."""
        src_dir = tmp_path / "src"
        src_dir.mkdir()
        for module_name in ("main.py", "utils.py"):
            (src_dir / module_name).write_text("pass")
        (tmp_path / "README.md").write_text("hi")

        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("x")

        modules_dir = tmp_path / "node_modules"
        modules_dir.mkdir()
        (modules_dir / "pkg").mkdir()
        return tmp_path

    @pytest.mark.asyncio
    async def test_basic_list(self, tool, populated_dir):
        """A flat listing shows regular entries and hides ignored directories."""
        listing = await tool.execute(path=str(populated_dir))
        for expected in ("README.md", "src"):
            assert expected in listing
        # .git and node_modules must be filtered out of the listing.
        for hidden in (".git", "node_modules"):
            assert hidden not in listing

    @pytest.mark.asyncio
    async def test_recursive(self, tool, populated_dir):
        """A recursive listing includes nested files by relative path."""
        listing = await tool.execute(path=str(populated_dir), recursive=True)
        for expected in ("src/main.py", "src/utils.py", "README.md"):
            assert expected in listing
        # Ignored directories stay hidden in recursive mode as well.
        for hidden in (".git", "node_modules"):
            assert hidden not in listing

    @pytest.mark.asyncio
    async def test_max_entries_truncation(self, tool, tmp_path):
        """``max_entries`` caps the output and the cap is reported."""
        for index in range(10):
            (tmp_path / f"file_{index}.txt").write_text("x")
        listing = await tool.execute(path=str(tmp_path), max_entries=3)
        assert "truncated" in listing
        assert "3 of 10" in listing

    @pytest.mark.asyncio
    async def test_empty_dir(self, tool, tmp_path):
        """An empty directory is described as empty."""
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()
        listing = await tool.execute(path=str(empty_dir))
        assert "empty" in listing.lower()

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        """A missing path produces an error result."""
        missing = tmp_path / "nope"
        listing = await tool.execute(path=str(missing))
        assert "Error" in listing
        assert "not found" in listing
|
||||
53
tests/test_gemini_thought_signature.py
Normal file
53
tests/test_gemini_thought_signature.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
|
||||
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
    """Provider-specific fields on a tool call and on its function survive parsing."""
    provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")

    # Build the fake response inside-out so each layer is named.
    function_part = SimpleNamespace(
        name="read_file",
        arguments='{"path":"todo.md"}',
        provider_specific_fields={"inner": "value"},
    )
    raw_tool_call = SimpleNamespace(
        id="call_123",
        function=function_part,
        provider_specific_fields={"thought_signature": "signed-token"},
    )
    assistant_message = SimpleNamespace(content=None, tool_calls=[raw_tool_call])
    choice = SimpleNamespace(finish_reason="tool_calls", message=assistant_message)
    response = SimpleNamespace(choices=[choice], usage=None)

    parsed = provider._parse_response(response)

    assert len(parsed.tool_calls) == 1
    first_call = parsed.tool_calls[0]
    assert first_call.provider_specific_fields == {"thought_signature": "signed-token"}
    assert first_call.function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
|
||||
def test_tool_call_request_serializes_provider_fields() -> None:
    """``to_openai_tool_call`` emits both levels of provider fields and JSON arguments."""
    outer_fields = {"thought_signature": "signed-token"}
    inner_fields = {"inner": "value"}
    request = ToolCallRequest(
        id="abc123xyz",
        name="read_file",
        arguments={"path": "todo.md"},
        provider_specific_fields=outer_fields,
        function_provider_specific_fields=inner_fields,
    )

    serialized = request.to_openai_tool_call()
    function_payload = serialized["function"]

    assert serialized["provider_specific_fields"] == {"thought_signature": "signed-token"}
    assert function_payload["provider_specific_fields"] == {"inner": "value"}
    assert function_payload["arguments"] == '{"path": "todo.md"}'
|
||||
@@ -3,18 +3,24 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class DummyProvider:
|
||||
class DummyProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_is_idempotent(tmp_path) -> None:
|
||||
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
||||
)
|
||||
|
||||
assert await service.trigger_now() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
    """``_decide`` retries after an error response and returns the second response's decision."""
    transient_failure = LLMResponse(content="429 rate limit", finish_reason="error")
    heartbeat_call = ToolCallRequest(
        id="hb_1",
        name="heartbeat",
        arguments={"action": "run", "tasks": "check open tasks"},
    )
    successful_reply = LLMResponse(content="", tool_calls=[heartbeat_call])
    provider = DummyProvider([transient_failure, successful_reply])

    # Record requested back-off delays instead of actually sleeping.
    recorded_delays: list[int] = []

    async def _record_sleep(delay: int) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr(asyncio, "sleep", _record_sleep)

    service = HeartbeatService(
        workspace=tmp_path,
        provider=provider,
        model="openai/gpt-4o-mini",
    )

    action, tasks = await service._decide("heartbeat content")

    assert action == "run"
    assert tasks == "check open tasks"
    assert provider.calls == 2
    assert recorded_delays == [1]
|
||||
|
||||
190
tests/test_loop_consolidation_tokens.py
Normal file
190
tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
    """Create an ``AgentLoop`` backed by a mocked provider.

    Args:
        tmp_path: Directory used as the loop's workspace.
        estimated_tokens: Value the mocked ``estimate_prompt_tokens`` reports.
        context_window_tokens: Context window size passed to the loop.

    Returns:
        The configured loop, with tool definitions stubbed to an empty list.
    """
    canned_reply = LLMResponse(content="ok", tool_calls=[])

    mock_provider = MagicMock()
    mock_provider.get_default_model.return_value = "test-model"
    mock_provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
    mock_provider.chat_with_retry = AsyncMock(return_value=canned_reply)

    agent_loop = AgentLoop(
        bus=MessageBus(),
        provider=mock_provider,
        workspace=tmp_path,
        model="test-model",
        context_window_tokens=context_window_tokens,
    )
    agent_loop.tools.get_definitions = MagicMock(return_value=[])
    return agent_loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
    """With an estimate of 100 tokens against a 200-token window, nothing is consolidated."""
    loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
    consolidate_mock = AsyncMock(return_value=True)
    loop.memory_consolidator.consolidate_messages = consolidate_mock  # type: ignore[method-assign]

    await loop.process_direct("hello", session_key="cli:test")

    loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
    """With an estimate of 1000 tokens against a 200-token window, consolidation is awaited."""
    loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    consolidate_mock = AsyncMock(return_value=True)
    loop.memory_consolidator.consolidate_messages = consolidate_mock  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {
            "role": "user" if text.startswith("u") else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:{second:02d}",
        }
        for second, text in enumerate(["u1", "a1", "u2"])
    ]
    loop.sessions.save(session)
    # Every message is reported as 500 tokens.
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)

    await loop.process_direct("hello", session_key="cli:test")

    assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
    """The archived chunk covers the first four messages, leaving the final user message."""
    loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    consolidate_mock = AsyncMock(return_value=True)
    loop.memory_consolidator.consolidate_messages = consolidate_mock  # type: ignore[method-assign]

    turns = ["u1", "a1", "u2", "a2", "u3"]
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {
            "role": "user" if text.startswith("u") else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:{second:02d}",
        }
        for second, text in enumerate(turns)
    ]
    loop.sessions.save(session)

    # Each message costs the same fixed number of tokens, looked up by content.
    token_map = dict.fromkeys(turns, 120)
    monkeypatch.setattr(
        memory_module,
        "estimate_message_tokens",
        lambda message: token_map[message["content"]],
    )

    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)

    archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
    archived_contents = [message["content"] for message in archived_chunk]
    assert archived_contents == ["u1", "a1", "u2", "a2"]
    assert session.last_consolidated == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
    """Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    consolidate_mock = AsyncMock(return_value=True)
    loop.memory_consolidator.consolidate_messages = consolidate_mock  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {
            "role": "user" if text.startswith("u") else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:{second:02d}",
        }
        for second, text in enumerate(["u1", "a1", "u2", "a2", "u3", "a3", "u4"])
    ]
    loop.sessions.save(session)

    # Scripted prompt estimates: 500 on the first call, 300 on the second,
    # and 80 on every call after that.
    scripted_estimates = iter([500, 300])

    def mock_estimate(_session):
        return (next(scripted_estimates, 80), "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)

    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)

    assert loop.memory_consolidator.consolidate_messages.await_count == 2
    assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
    """Once triggered, consolidation should continue until it drops below half threshold."""
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    consolidate_mock = AsyncMock(return_value=True)
    loop.memory_consolidator.consolidate_messages = consolidate_mock  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {
            "role": "user" if text.startswith("u") else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:{second:02d}",
        }
        for second, text in enumerate(["u1", "a1", "u2", "a2", "u3", "a3", "u4"])
    ]
    loop.sessions.save(session)

    # Scripted prompt estimates: 500 on the first call, 150 on the second
    # (already under the 200-token window), and 80 on every call after that.
    scripted_estimates = iter([500, 150])

    def mock_estimate(_session):
        return (next(scripted_estimates, 80), "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)

    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)

    assert loop.memory_consolidator.consolidate_messages.await_count == 2
    assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
    """Verify preflight consolidation runs before the LLM call in process_direct."""
    # Both stubs append a marker here so the relative order can be asserted.
    order: list[str] = []

    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)

    async def track_consolidate(messages):
        order.append("consolidate")
        return True

    async def track_llm(*args, **kwargs):
        order.append("llm")
        return LLMResponse(content="ok", tool_calls=[])

    loop.memory_consolidator.consolidate_messages = track_consolidate  # type: ignore[method-assign]
    loop.provider.chat_with_retry = track_llm

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {
            "role": "user" if text.startswith("u") else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:{second:02d}",
        }
        for second, text in enumerate(["u1", "a1", "u2"])
    ]
    loop.sessions.save(session)
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)

    # The first estimate is 1000 tokens; every later estimate is 80.
    scripted_estimates = iter([1000])

    def mock_estimate(_session):
        return (next(scripted_estimates, 80), "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]

    await loop.process_direct("hello", session_key="cli:test")

    assert "consolidate" in order
    assert "llm" in order
    assert order.index("consolidate") < order.index("llm")
|
||||
@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = 500
|
||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||
return loop
|
||||
|
||||
|
||||
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
||||
|
||||
|
||||
def test_save_turn_keeps_tool_results_under_16k() -> None:
    """A 12,000-character tool result is stored in the session unmodified."""
    loop = _mk_loop()
    session = Session(key="test:tool-result")
    payload = "x" * 12_000
    tool_message = {
        "role": "tool",
        "tool_call_id": "call_1",
        "name": "read_file",
        "content": payload,
    }

    loop._save_turn(session, [tool_message], skip=0)

    stored = session.messages[0]
    assert stored["content"] == payload
|
||||
|
||||
@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||
"""Create a mock session with messages."""
|
||||
session = MagicMock()
|
||||
session.messages = [
|
||||
def _make_messages(message_count: int = 30):
|
||||
"""Create a list of mock messages."""
|
||||
return [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
session.last_consolidated = 0
|
||||
return session
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
|
||||
)
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
    """``LLMProvider`` test double that replays a fixed list of responses.

    ``calls`` counts how many times ``chat`` was invoked. Once the scripted
    responses run out, ``chat`` returns an empty response with no tool calls.
    """

    def __init__(self, responses: list[LLMResponse]):
        super().__init__()
        # Copy so that popping does not mutate the caller's list.
        self._responses = list(responses)
        self.calls = 0

    async def chat(self, *args, **kwargs) -> LLMResponse:
        """Return the next scripted response, ignoring all arguments."""
        self.calls += 1
        if not self._responses:
            return LLMResponse(content="", tool_calls=[])
        return self._responses.pop(0)

    def get_default_model(self) -> str:
        """Return the fixed model name used by these tests."""
        return "test-model"
|
||||
|
||||
|
||||
class TestMemoryConsolidationTypeHandling:
|
||||
"""Test that consolidation handles various argument types correctly."""
|
||||
|
||||
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -111,9 +126,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -126,21 +142,23 @@ class TestMemoryConsolidationTypeHandling:
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when messages < keep_count."""
|
||||
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
session = _make_session(message_count=10)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages: list[dict] = []
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
@@ -165,9 +183,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -190,9 +209,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -213,9 +233,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -224,7 +245,7 @@ class TestMemoryConsolidationTypeHandling:
|
||||
"""Do not persist partial results when required fields are missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
@@ -236,21 +257,20 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Do not append history if memory_update is missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
@@ -262,51 +282,152 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Null required fields should be rejected before persistence."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=None,
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=" ",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
    """An error response followed by a valid tool response consolidates after one retry."""
    store = MemoryStore(tmp_path)
    transient_failure = LLMResponse(content="503 server error", finish_reason="error")
    successful_reply = _make_tool_response(
        history_entry="[2026-01-01] User discussed testing.",
        memory_update="# Memory\nUser likes testing.",
    )
    provider = ScriptedProvider([transient_failure, successful_reply])
    messages = _make_messages(message_count=60)

    # Record requested back-off delays instead of actually sleeping.
    recorded_delays: list[int] = []

    async def _record_sleep(delay: int) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    outcome = await store.consolidate(messages, provider, "test-model")

    assert outcome is True
    assert provider.calls == 2
    assert recorded_delays == [1]
|
||||
|
||||
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
    """Consolidation no longer passes generation params — the provider owns them."""
    store = MemoryStore(tmp_path)
    canned_reply = _make_tool_response(
        history_entry="[2026-01-01] User discussed testing.",
        memory_update="# Memory\nUser likes testing.",
    )
    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(return_value=canned_reply)
    messages = _make_messages(message_count=60)

    outcome = await store.consolidate(messages, provider, "test-model")

    assert outcome is True
    provider.chat_with_retry.assert_awaited_once()
    _, call_kwargs = provider.chat_with_retry.await_args
    assert call_kwargs["model"] == "test-model"
    # None of the generation parameters may be forwarded.
    for forbidden in ("temperature", "max_tokens", "reasoning_effort"):
        assert forbidden not in call_kwargs
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
    """Forced tool_choice rejected by provider -> retry with auto and succeed."""
    store = MemoryStore(tmp_path)
    rejection_text = (
        "Error calling LLM: litellm.BadRequestError: "
        "The tool_choice parameter does not support being set to required or object"
    )
    error_resp = LLMResponse(content=rejection_text, finish_reason="error", tool_calls=[])
    ok_resp = _make_tool_response(
        history_entry="[2026-01-01] Fallback worked.",
        memory_update="# Memory\nFallback OK.",
    )

    # Keyword arguments of every provider call, in order.
    call_log: list[dict] = []

    async def _tracking_chat(**kwargs):
        call_log.append(kwargs)
        # Only the very first call is rejected.
        if len(call_log) == 1:
            return error_resp
        return ok_resp

    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
    messages = _make_messages(message_count=60)

    outcome = await store.consolidate(messages, provider, "test-model")

    assert outcome is True
    assert len(call_log) == 2
    first_call, second_call = call_log
    assert isinstance(first_call["tool_choice"], dict)
    assert second_call["tool_choice"] == "auto"
    assert "Fallback worked." in store.history_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
    """Forced rejected, auto retry also produces no tool call -> return False.

    The provider first returns a tool_choice error, then a plain-text reply
    without tool calls. Consolidation is expected to report failure and to
    leave the history file uncreated.
    """
    scripted_replies = [
        LLMResponse(
            content="Error: tool_choice must be none or auto",
            finish_reason="error",
            tool_calls=[],
        ),
        LLMResponse(
            content="Here is a summary.",
            finish_reason="stop",
            tool_calls=[],
        ),
    ]
    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(side_effect=scripted_replies)
    store = MemoryStore(tmp_path)

    outcome = await store.consolidate(_make_messages(message_count=60), provider, "test-model")

    assert outcome is False
    assert not store.history_file.exists()
|
||||
|
||||
@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
|
||||
125
tests/test_provider_retry.py
Normal file
125
tests/test_provider_retry.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
    """Test double whose ``chat`` replays a fixed script of outcomes.

    Each scripted item is either returned as the response or, when it is an
    exception instance, raised. The number of calls and the keyword arguments
    of the most recent call are recorded for assertions.
    """

    def __init__(self, responses):
        super().__init__()
        # Copy so the caller's sequence is not consumed by pop().
        self._responses = list(responses)
        self.calls = 0
        self.last_kwargs: dict = {}

    async def chat(self, *args, **kwargs) -> LLMResponse:
        """Return (or raise) the next scripted outcome and record the call."""
        self.calls += 1
        self.last_kwargs = kwargs
        outcome = self._responses.pop(0)
        if not isinstance(outcome, BaseException):
            return outcome
        raise outcome

    def get_default_model(self) -> str:
        """Return the fixed model name used throughout these tests."""
        return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
    """A "429 rate limit" error response is retried and the next reply is returned.

    ``asyncio.sleep`` inside ``nanobot.providers.base`` is replaced so the test
    records the requested delays instead of waiting.
    """
    recorded_delays: list[int] = []

    async def _record_sleep(delay: int) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    script = [
        LLMResponse(content="429 rate limit", finish_reason="error"),
        LLMResponse(content="ok"),
    ]
    provider = ScriptedProvider(script)
    request = [{"role": "user", "content": "hello"}]

    response = await provider.chat_with_retry(messages=request)

    assert response.finish_reason == "stop"
    assert response.content == "ok"
    assert provider.calls == 2
    assert recorded_delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
    """A "401 unauthorized" error response is returned after a single call, with no sleep."""
    recorded_delays: list[int] = []

    async def _record_sleep(delay: int) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    unauthorized = LLMResponse(content="401 unauthorized", finish_reason="error")
    provider = ScriptedProvider([unauthorized])
    request = [{"role": "user", "content": "hello"}]

    response = await provider.chat_with_retry(messages=request)

    assert response.content == "401 unauthorized"
    assert provider.calls == 1
    assert recorded_delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
    """After three retried errors the fourth (still failing) response is returned.

    The test expects four provider calls in total and recorded delays of 1, 2 and 4.
    """
    recorded_delays: list[int] = []

    async def _record_sleep(delay: int) -> None:
        recorded_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    failing_script = [
        LLMResponse(content="429 rate limit a", finish_reason="error"),
        LLMResponse(content="429 rate limit b", finish_reason="error"),
        LLMResponse(content="429 rate limit c", finish_reason="error"),
        LLMResponse(content="503 final server error", finish_reason="error"),
    ]
    provider = ScriptedProvider(failing_script)
    request = [{"role": "user", "content": "hello"}]

    response = await provider.chat_with_retry(messages=request)

    assert response.content == "503 final server error"
    assert provider.calls == 4
    assert recorded_delays == [1, 2, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_preserves_cancelled_error() -> None:
    """A CancelledError raised by ``chat`` must propagate out of ``chat_with_retry``."""
    cancelling_provider = ScriptedProvider([asyncio.CancelledError()])
    request = [{"role": "user", "content": "hello"}]

    with pytest.raises(asyncio.CancelledError):
        await cancelling_provider.chat_with_retry(messages=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
    """When callers omit generation params, provider.generation defaults are used."""
    defaults = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
    provider = ScriptedProvider([LLMResponse(content="ok")])
    provider.generation = defaults

    await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    # last_kwargs holds what chat() finally received.
    forwarded = provider.last_kwargs
    assert forwarded["temperature"] == 0.2
    assert forwarded["max_tokens"] == 321
    assert forwarded["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
    """Explicit kwargs should override provider.generation defaults."""
    defaults = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
    provider = ScriptedProvider([LLMResponse(content="ok")])
    provider.generation = defaults

    request = [{"role": "user", "content": "hello"}]
    await provider.chat_with_retry(
        messages=request,
        temperature=0.9,
        max_tokens=9999,
        reasoning_effort="low",
    )

    # The values passed by the caller, not the defaults, must reach chat().
    forwarded = provider.last_kwargs
    assert forwarded["temperature"] == 0.9
    assert forwarded["max_tokens"] == 9999
    assert forwarded["reasoning_effort"] == "low"
|
||||
@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
||||
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call["group_openid"] == "group123"
|
||||
assert call["msg_id"] == "msg1"
|
||||
assert call["msg_seq"] == 2
|
||||
assert call == {
|
||||
"group_openid": "group123",
|
||||
"msg_type": 0,
|
||||
"content": "hello",
|
||||
"msg_id": "msg1",
|
||||
"msg_seq": 2,
|
||||
}
|
||||
assert not channel._client.api.c2c_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
    """Sending to a chat with no cached chat type goes through the C2C API only.

    The single recorded call must carry msg_type 0, the text content, the
    originating message id and msg_seq 2; the group API must stay unused.
    """
    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()

    outbound = OutboundMessage(
        channel="qq",
        chat_id="user123",
        content="hello",
        metadata={"message_id": "msg1"},
    )
    await channel.send(outbound)

    api = channel._client.api
    assert len(api.c2c_calls) == 1
    expected_call = {
        "openid": "user123",
        "msg_type": 0,
        "content": "hello",
        "msg_id": "msg1",
        "msg_seq": 2,
    }
    assert api.c2c_calls[0] == expected_call
    assert not api.group_calls
|
||||
|
||||
76
tests/test_restart_command.py
Normal file
76
tests/test_restart_command.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for /restart slash command."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
|
||||
def _make_loop():
    """Create a minimal AgentLoop with mocked dependencies.

    Returns:
        A ``(loop, bus)`` tuple: the AgentLoop and the real MessageBus it uses.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    # A mock workspace that still supports the ``workspace / "name"`` operator.
    fake_workspace = MagicMock()
    fake_workspace.__truediv__ = MagicMock(return_value=MagicMock())

    message_bus = MessageBus()

    # These collaborators are patched out only while the loop is constructed.
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager"):
                agent_loop = AgentLoop(
                    bus=message_bus,
                    provider=fake_provider,
                    workspace=fake_workspace,
                )
    return agent_loop, message_bus
|
||||
|
||||
|
||||
class TestRestartCommand:
    """Behavioural tests for the ``/restart`` slash command."""

    @pytest.mark.asyncio
    async def test_restart_sends_message_and_calls_execv(self):
        """_handle_restart publishes a "Restarting" notice and later calls os.execv."""
        loop, bus = _make_loop()
        restart_msg = InboundMessage(
            channel="cli", sender_id="user", chat_id="direct", content="/restart"
        )

        # os.execv stays patched for the whole test so the process is never replaced.
        with patch("nanobot.agent.loop.os.execv") as mock_execv:
            await loop._handle_restart(restart_msg)
            notice = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
            assert "Restarting" in notice.content

            # NOTE(review): presumably execv is scheduled after a short delay — confirm in _handle_restart.
            await asyncio.sleep(1.5)
            mock_execv.assert_called_once()

    @pytest.mark.asyncio
    async def test_restart_intercepted_in_run_loop(self):
        """Verify /restart is handled at the run-loop level, not inside _dispatch."""
        loop, bus = _make_loop()
        restart_msg = InboundMessage(
            channel="telegram", sender_id="u1", chat_id="c1", content="/restart"
        )

        with patch.object(loop, "_handle_restart") as mock_handle:
            mock_handle.return_value = None
            await bus.publish_inbound(restart_msg)

            # Let run() process the queued message briefly, then stop it.
            loop._running = True
            runner = asyncio.create_task(loop.run())
            await asyncio.sleep(0.1)
            loop._running = False
            runner.cancel()
            try:
                await runner
            except asyncio.CancelledError:
                pass

            mock_handle.assert_called_once()

    @pytest.mark.asyncio
    async def test_help_includes_restart(self):
        """The /help reply must mention the /restart command."""
        loop, bus = _make_loop()
        help_msg = InboundMessage(
            channel="telegram", sender_id="u1", chat_id="c1", content="/help"
        )

        reply = await loop._process_message(help_msg)

        assert reply is not None
        assert "/restart" in reply.content
|
||||
127
tests/test_skill_creator_scripts.py
Normal file
127
tests/test_skill_creator_scripts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Locate the skill-creator scripts relative to this test file (tests/ sits one
# level below the repository root) rather than the current working directory,
# so the imports below work no matter where pytest is launched from.
SCRIPT_DIR = (
    Path(__file__).resolve().parents[1] / "nanobot" / "skills" / "skill-creator" / "scripts"
)
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

# The scripts are standalone modules, not part of a package, so they are
# imported by bare name once their directory is on sys.path.
init_skill = importlib.import_module("init_skill")
package_skill = importlib.import_module("package_skill")
quick_validate = importlib.import_module("quick_validate")
|
||||
|
||||
|
||||
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
    """init_skill with all three resource dirs and examples creates the scaffold files."""
    requested_dirs = ["scripts", "references", "assets"]

    skill_dir = init_skill.init_skill(
        "demo-skill",
        tmp_path,
        requested_dirs,
        include_examples=True,
    )

    assert skill_dir == tmp_path / "demo-skill"
    expected_files = (
        skill_dir / "SKILL.md",
        skill_dir / "scripts" / "example.py",
        skill_dir / "references" / "api_reference.md",
        skill_dir / "assets" / "example_asset.txt",
    )
    for expected in expected_files:
        assert expected.exists()
|
||||
|
||||
|
||||
def test_validate_skill_accepts_existing_skill_creator() -> None:
    """The bundled skill-creator skill must pass its own validator."""
    # Anchor to this test file (tests/ is one level below the repo root) so the
    # lookup does not depend on the directory pytest was started from.
    repo_root = Path(__file__).resolve().parents[1]
    skill_creator_dir = repo_root / "nanobot" / "skills" / "skill-creator"

    valid, message = quick_validate.validate_skill(skill_creator_dir)

    # The validator's message is used as the assertion message on failure.
    assert valid, message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
    """A SKILL.md whose description is still a TODO placeholder fails validation."""
    skill_md = (
        "---\n"
        "name: placeholder-skill\n"
        'description: "[TODO: fill me in]"\n'
        "---\n"
        "# Placeholder\n"
    )
    skill_dir = tmp_path / "placeholder-skill"
    skill_dir.mkdir()
    (skill_dir / "SKILL.md").write_text(skill_md, encoding="utf-8")

    valid, message = quick_validate.validate_skill(skill_dir)

    assert not valid
    assert "TODO placeholder" in message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
    """An extra README.md next to SKILL.md in the skill root fails validation."""
    skill_md = (
        "---\n"
        "name: bad-root-skill\n"
        "description: Valid description\n"
        "---\n"
        "# Skill\n"
    )
    skill_dir = tmp_path / "bad-root-skill"
    skill_dir.mkdir()
    (skill_dir / "SKILL.md").write_text(skill_md, encoding="utf-8")
    # The unexpected root-level file the validator is expected to reject.
    stray_file = skill_dir / "README.md"
    stray_file.write_text("extra\n", encoding="utf-8")

    valid, message = quick_validate.validate_skill(skill_dir)

    assert not valid
    assert "Unexpected file or directory in skill root" in message
|
||||
|
||||
|
||||
def test_package_skill_creates_archive(tmp_path: Path) -> None:
    """package_skill writes ``<name>.skill`` containing the skill files under ``<name>/``."""
    skill_md = (
        "---\n"
        "name: package-me\n"
        "description: Package this skill.\n"
        "---\n"
        "# Skill\n"
    )
    skill_dir = tmp_path / "package-me"
    skill_dir.mkdir()
    (skill_dir / "SKILL.md").write_text(skill_md, encoding="utf-8")
    helper_dir = skill_dir / "scripts"
    helper_dir.mkdir()
    (helper_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")

    output_dir = tmp_path / "dist"
    archive_path = package_skill.package_skill(skill_dir, output_dir)

    assert archive_path == (tmp_path / "dist" / "package-me.skill")
    assert archive_path.exists()
    # The .skill file is read back as a zip archive to check its members.
    with zipfile.ZipFile(archive_path, "r") as archive:
        member_names = set(archive.namelist())
    assert "package-me/SKILL.md" in member_names
    assert "package-me/scripts/helper.py" in member_names
|
||||
|
||||
|
||||
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
    """package_skill must refuse to package a skill tree that contains a symlink.

    A link inside ``scripts/`` points at a file outside the skill directory.
    Packaging is expected to return ``None`` and to leave no archive behind.
    When the platform cannot create symlinks the test is reported as skipped
    instead of silently passing.
    """
    # Imported locally because this module has no file-level pytest import.
    import pytest

    skill_dir = tmp_path / "symlink-skill"
    skill_dir.mkdir()
    (skill_dir / "SKILL.md").write_text(
        "---\n"
        "name: symlink-skill\n"
        "description: Reject symlinks during packaging.\n"
        "---\n"
        "# Skill\n",
        encoding="utf-8",
    )
    scripts_dir = skill_dir / "scripts"
    scripts_dir.mkdir()
    # The link target lives outside the skill directory.
    target = tmp_path / "outside.txt"
    target.write_text("secret\n", encoding="utf-8")
    link = scripts_dir / "outside.txt"

    try:
        link.symlink_to(target)
    except (OSError, NotImplementedError) as exc:
        # Symlink creation can fail for lack of platform support or privileges.
        pytest.skip(f"symlinks are not available in this environment: {exc}")

    archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")

    assert archive_path is None
    assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
|
||||
@@ -165,3 +165,46 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
    """The assistant tool-call turn replayed to the provider keeps its reasoning fields.

    The first scripted reply contains a tool call plus ``reasoning_content`` and
    ``thinking_blocks``. The messages passed to the following call are captured
    and must contain exactly one assistant tool-call message carrying both fields.
    """
    from nanobot.agent.subagent import SubagentManager
    from nanobot.bus.queue import MessageBus
    from nanobot.providers.base import LLMResponse, ToolCallRequest

    captured_second_call: list[dict] = []
    calls_made = 0

    async def scripted_chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Snapshot the message list handed to every call after the first.
            captured_second_call[:] = messages
            return LLMResponse(content="done", tool_calls=[])
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
        )

    async def fake_execute(self, name, arguments):
        return "tool result"

    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = scripted_chat_with_retry
    mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)

    # Every tool execution is replaced by a canned result.
    monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)

    origin = {"channel": "test", "chat_id": "c1"}
    await mgr._run_subagent("sub-1", "do task", "label", origin)

    tool_turns = []
    for msg in captured_second_call:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            tool_turns.append(msg)

    assert len(tool_turns) == 1
    replayed = tool_turns[0]
    assert replayed["reasoning_content"] == "hidden reasoning"
    assert replayed["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
@@ -27,9 +30,11 @@ class _FakeUpdater:
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
self.get_me_calls = 0
|
||||
|
||||
async def get_me(self):
|
||||
return SimpleNamespace(username="nanobot_test")
|
||||
self.get_me_calls += 1
|
||||
return SimpleNamespace(id=999, username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
@@ -37,6 +42,15 @@ class _FakeBot:
|
||||
async def send_message(self, **kwargs) -> None:
|
||||
self.sent_messages.append(kwargs)
|
||||
|
||||
async def send_chat_action(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def get_file(self, file_id: str):
|
||||
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
|
||||
async def _fake_download(path) -> None:
|
||||
pass
|
||||
return SimpleNamespace(download_to_drive=_fake_download)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
@@ -87,6 +101,35 @@ class _FakeBuilder:
|
||||
return self.app
|
||||
|
||||
|
||||
def _make_telegram_update(
|
||||
*,
|
||||
chat_type: str = "group",
|
||||
text: str | None = None,
|
||||
caption: str | None = None,
|
||||
entities=None,
|
||||
caption_entities=None,
|
||||
reply_to_message=None,
|
||||
):
|
||||
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||
chat_id=-100123,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
reply_to_message=reply_to_message,
|
||||
photo=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
media_group_id=None,
|
||||
message_thread_id=None,
|
||||
message_id=1,
|
||||
)
|
||||
return SimpleNamespace(message=message, effective_user=user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||
config = TelegramConfig(
|
||||
@@ -131,6 +174,10 @@ def test_get_extension_falls_back_to_original_filename() -> None:
|
||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||
|
||||
|
||||
def test_telegram_group_policy_defaults_to_mention() -> None:
    """A TelegramConfig built without arguments has group_policy "mention"."""
    default_config = TelegramConfig()
    assert default_config.group_policy == "mention"
|
||||
|
||||
|
||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||
|
||||
@@ -182,3 +229,371 @@ async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
    """Under the "mention" policy a plain group message is not forwarded.

    The test also expects exactly one ``get_me`` call on the fake bot.
    """
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    update = _make_telegram_update(text="hello everyone")
    await channel._on_message(update, None)

    assert forwarded == []
    assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
    """Two messages that @-mention the bot are both forwarded with one ``get_me`` call."""
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    # The entity spans the 13 characters of "@nanobot_test" at the start of the text.
    mention = SimpleNamespace(type="mention", offset=0, length=13)
    first_update = _make_telegram_update(text="@nanobot_test hi", entities=[mention])
    await channel._on_message(first_update, None)
    second_update = _make_telegram_update(text="@nanobot_test again", entities=[mention])
    await channel._on_message(second_update, None)

    assert len(forwarded) == 2
    assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_caption_mention() -> None:
    """A mention inside a caption (not the text) is enough for the message to be forwarded."""
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    # The entity spans the 13 characters of "@nanobot_test" at the start of the caption.
    mention = SimpleNamespace(type="mention", offset=0, length=13)
    update = _make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention])
    await channel._on_message(update, None)

    assert len(forwarded) == 1
    assert forwarded[0]["content"] == "@nanobot_test photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
    """A reply to a message authored by user id 999 is forwarded under the "mention" policy."""
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    # NOTE(review): 999 presumably matches the id the fake bot reports from get_me — confirm.
    replied_to = SimpleNamespace(from_user=SimpleNamespace(id=999))
    update = _make_telegram_update(text="reply", reply_to_message=replied_to)
    await channel._on_message(update, None)

    assert len(forwarded) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_group_policy_open_accepts_plain_group_message() -> None:
    """Under the "open" policy a plain group message is forwarded and ``get_me`` is never called."""
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="open"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    update = _make_telegram_update(text="hello group")
    await channel._on_message(update, None)

    assert len(forwarded) == 1
    assert channel._app.bot.get_me_calls == 0
|
||||
|
||||
|
||||
def test_extract_reply_context_no_reply() -> None:
    """When there is no reply_to_message, _extract_reply_context returns None."""
    standalone = SimpleNamespace(reply_to_message=None)
    context = TelegramChannel._extract_reply_context(standalone)
    assert context is None
|
||||
|
||||
|
||||
def test_extract_reply_context_with_text() -> None:
    """When reply has text, return prefixed string."""
    replied_to = SimpleNamespace(text="Hello world", caption=None)
    incoming = SimpleNamespace(reply_to_message=replied_to)
    context = TelegramChannel._extract_reply_context(incoming)
    assert context == "[Reply to: Hello world]"
|
||||
|
||||
|
||||
def test_extract_reply_context_with_caption_only() -> None:
    """When reply has only caption (no text), caption is used."""
    replied_to = SimpleNamespace(text=None, caption="Photo caption")
    incoming = SimpleNamespace(reply_to_message=replied_to)
    context = TelegramChannel._extract_reply_context(incoming)
    assert context == "[Reply to: Photo caption]"
|
||||
|
||||
|
||||
def test_extract_reply_context_truncation() -> None:
    """Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
    oversized = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
    replied_to = SimpleNamespace(text=oversized, caption=None)
    incoming = SimpleNamespace(reply_to_message=replied_to)

    context = TelegramChannel._extract_reply_context(incoming)

    assert context is not None
    assert context.startswith("[Reply to: ")
    assert context.endswith("...]")
    # Expected length: the wrapper, the kept prefix of the text, and the ellipsis.
    wrapper_len = len("[Reply to: ]")
    ellipsis_len = len("...")
    assert len(context) == wrapper_len + TELEGRAM_REPLY_CONTEXT_MAX_LEN + ellipsis_len
|
||||
|
||||
|
||||
def test_extract_reply_context_no_text_returns_none() -> None:
    """When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
    replied_to = SimpleNamespace(text=None, caption=None)
    incoming = SimpleNamespace(reply_to_message=replied_to)
    context = TelegramChannel._extract_reply_context(incoming)
    assert context is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_includes_reply_context() -> None:
    """When user replies to a message, content passed to bus starts with reply context."""
    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="open"
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    replied_to = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
    update = _make_telegram_update(text="translate this", reply_to_message=replied_to)
    await channel._on_message(update, None)

    assert len(forwarded) == 1
    content = forwarded[0]["content"]
    assert content.startswith("[Reply to: Hello]")
    assert "translate this" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_message_media_returns_path_when_download_succeeds(
    monkeypatch, tmp_path
) -> None:
    """_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
    media_dir = tmp_path / "media" / "telegram"
    media_dir.mkdir(parents=True)

    def _fake_media_dir(channel=None):
        # Per-channel directory when a channel is given, otherwise the media root.
        if channel:
            return media_dir
        return tmp_path / "media"

    monkeypatch.setattr("nanobot.channels.telegram.get_media_dir", _fake_media_dir)

    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)
    # get_file resolves to an object whose download_to_drive coroutine does nothing.
    downloadable = SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
    channel._app.bot.get_file = AsyncMock(return_value=downloadable)

    photo_message = SimpleNamespace(
        photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
        voice=None,
        audio=None,
        document=None,
        video=None,
        video_note=None,
        animation=None,
    )

    paths, parts = await channel._download_message_media(photo_message)

    assert len(paths) == 1
    assert len(parts) == 1
    assert "fid123" in paths[0]
    assert "[image:" in parts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
    """When user replies to a message with media, that media is downloaded and attached to the turn."""
    media_dir = tmp_path / "media" / "telegram"
    media_dir.mkdir(parents=True)

    def _fake_media_dir(channel=None):
        # Per-channel directory when a channel is given, otherwise the media root.
        if channel:
            return media_dir
        return tmp_path / "media"

    monkeypatch.setattr("nanobot.channels.telegram.get_media_dir", _fake_media_dir)

    config = TelegramConfig(
        enabled=True, token="123:abc", allow_from=["*"], group_policy="open"
    )
    channel = TelegramChannel(config, MessageBus())
    fake_app = _FakeApp(lambda: None)
    # get_file resolves to an object whose download_to_drive coroutine does nothing.
    downloadable = SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
    fake_app.bot.get_file = AsyncMock(return_value=downloadable)
    channel._app = fake_app

    forwarded = []

    async def _record_handle(**kwargs) -> None:
        forwarded.append(kwargs)

    def _no_typing(_chat_id):
        return None

    channel._handle_message = _record_handle
    channel._start_typing = _no_typing

    # The replied-to message carries only a photo: no text and no caption.
    replied_photo = SimpleNamespace(
        text=None,
        caption=None,
        photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
        document=None,
        voice=None,
        audio=None,
        video=None,
        video_note=None,
        animation=None,
    )
    update = _make_telegram_update(
        text="what is the image?",
        reply_to_message=replied_photo,
    )
    await channel._on_message(update, None)

    assert len(forwarded) == 1
    turn = forwarded[0]
    assert turn["content"].startswith("[Reply to: [image:")
    assert "what is the image?" in turn["content"]
    assert len(turn["media"]) == 1
    assert "reply_photo_fid" in turn["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
    """A reply to a photo that cannot be fetched still forwards the user's text.

    ``bot.get_file`` is replaced with ``None`` so it cannot be called. The
    turn must still reach ``_handle_message`` exactly once, containing the
    user's text and an empty media list.
    """
    config = TelegramConfig(
        enabled=True,
        token="123:abc",
        allow_from=["*"],
        group_policy="open",
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)
    channel._app.bot.get_file = None

    captured = []

    async def record_call(**kwargs) -> None:
        captured.append(kwargs)

    channel._handle_message = record_call
    channel._start_typing = lambda _chat_id: None

    replied_photo = SimpleNamespace(file_id="x", mime_type="image/jpeg")
    replied_message = SimpleNamespace(
        text=None,
        caption=None,
        photo=[replied_photo],
        document=None,
        voice=None,
        audio=None,
        video=None,
        video_note=None,
        animation=None,
    )
    update = _make_telegram_update(text="what is this?", reply_to_message=replied_message)

    await channel._on_message(update, None)

    assert len(captured) == 1
    turn = captured[0]
    assert "what is this?" in turn["content"]
    assert turn["media"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
    """Replying to a captioned photo yields both the caption tag and the photo.

    The handled content must contain ``[Reply to: A cute cat]`` alongside the
    user's own text, and the media list must hold one path that includes the
    photo's file id.
    """
    telegram_media = tmp_path / "media" / "telegram"
    telegram_media.mkdir(parents=True)

    def fake_get_media_dir(channel=None):
        # Lookups that name a channel get the telegram dir; bare lookups get its parent.
        return telegram_media if channel else tmp_path / "media"

    monkeypatch.setattr("nanobot.channels.telegram.get_media_dir", fake_get_media_dir)

    config = TelegramConfig(
        enabled=True,
        token="123:abc",
        allow_from=["*"],
        group_policy="open",
    )
    channel = TelegramChannel(config, MessageBus())

    # get_file resolves to an object whose download_to_drive is an awaitable no-op.
    fake_file = SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
    fake_app = _FakeApp(lambda: None)
    fake_app.bot.get_file = AsyncMock(return_value=fake_file)
    channel._app = fake_app

    captured = []

    async def record_call(**kwargs) -> None:
        captured.append(kwargs)

    channel._handle_message = record_call
    channel._start_typing = lambda _chat_id: None

    cat_photo = SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")
    replied_message = SimpleNamespace(
        text=None,
        caption="A cute cat",
        photo=[cat_photo],
        document=None,
        voice=None,
        audio=None,
        video=None,
        video_note=None,
        animation=None,
    )
    update = _make_telegram_update(
        text="what breed is this?",
        reply_to_message=replied_message,
    )

    await channel._on_message(update, None)

    assert len(captured) == 1
    turn = captured[0]
    assert "[Reply to: A cute cat]" in turn["content"]
    assert "what breed is this?" in turn["content"]
    assert len(turn["media"]) == 1
    assert "cat_fid" in turn["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forward_command_does_not_inject_reply_context() -> None:
    """A slash command sent as a reply is forwarded verbatim by ``_forward_command``.

    Even though the update replies to an earlier text message, the content
    handed to ``_handle_message`` must be exactly ``/new``.
    """
    config = TelegramConfig(
        enabled=True,
        token="123:abc",
        allow_from=["*"],
        group_policy="open",
    )
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    captured = []

    async def record_call(**kwargs) -> None:
        captured.append(kwargs)

    channel._handle_message = record_call

    earlier_message = SimpleNamespace(
        text="some old message",
        message_id=2,
        from_user=SimpleNamespace(id=1),
    )
    update = _make_telegram_update(text="/new", reply_to_message=earlier_message)

    await channel._forward_command(update, None)

    assert len(captured) == 1
    assert captured[0]["content"] == "/new"
|
||||
|
||||
@@ -108,6 +108,32 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
    """Tilde-prefixed paths on both sides of a redirect are picked up."""
    extracted = ExecTool._extract_absolute_paths("cat ~/.nanobot/config.json > ~/out.txt")

    for expected in ("~/.nanobot/config.json", "~/out.txt"):
        assert expected in extracted
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
    """Double-quoted absolute and tilde paths are both picked up."""
    extracted = ExecTool._extract_absolute_paths('cat "/tmp/data.txt" "~/.nanobot/config.json"')

    for expected in ("/tmp/data.txt", "~/.nanobot/config.json"):
        assert expected in extracted
|
||||
|
||||
|
||||
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
    """With workspace restriction on, a ``~/`` path outside ``tmp_path`` is rejected."""
    restricted = ExecTool(restrict_to_workspace=True)
    working_dir = str(tmp_path)

    verdict = restricted._guard_command("cat ~/.nanobot/config.json", working_dir)

    assert verdict == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
    """With workspace restriction on, a double-quoted ``~/`` path is rejected too."""
    restricted = ExecTool(restrict_to_workspace=True)
    working_dir = str(tmp_path)

    verdict = restricted._guard_command('cat "~/.nanobot/config.json"', working_dir)

    assert verdict == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
@@ -337,3 +363,44 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
|
||||
# --- ExecTool enhancement tests ---
|
||||
|
||||
|
||||
async def test_exec_always_returns_exit_code() -> None:
    """A successful command still reports ``Exit code: 0`` alongside its output."""
    output = await ExecTool().execute(command="echo hello")

    for fragment in ("Exit code: 0", "hello"):
        assert fragment in output
|
||||
|
||||
|
||||
async def test_exec_head_tail_truncation() -> None:
    """Oversized output is truncated but keeps its head and the exit-code line."""
    runner = ExecTool()

    # Two 6000-character lines: a payload meant to exceed the tool's output limit.
    payload = "\n".join(("A" * 6000, "B" * 6000))
    output = await runner.execute(command="echo '" + payload + "'")

    assert "chars truncated" in output
    # The retained head must begin with the run of As.
    assert output.startswith("A")
    # The exit-code line follows the Bs, so its presence shows the tail was kept.
    assert "Exit code:" in output
|
||||
|
||||
|
||||
async def test_exec_timeout_parameter() -> None:
    """A per-call ``timeout`` takes precedence over the constructor's value."""
    runner = ExecTool(timeout=60)

    # One second is far shorter than the sleep, so the command must time out.
    output = await runner.execute(command="sleep 10", timeout=1)

    for fragment in ("timed out", "1 seconds"):
        assert fragment in output
|
||||
|
||||
|
||||
async def test_exec_timeout_capped_at_max() -> None:
    """An oversized ``timeout`` is accepted and the command still runs normally."""
    runner = ExecTool()

    # Must not raise; the value is expected to be clamped (to 600) rather than rejected.
    output = await runner.execute(command="echo ok", timeout=9999)

    assert "Exit code: 0" in output
|
||||
|
||||
Reference in New Issue
Block a user