Merge remote-tracking branch 'origin/main' into pr-1796
This commit is contained in:
22
README.md
22
README.md
@@ -1112,6 +1112,28 @@ Use `toolTimeout` to override the default 30s per-call timeout for slow servers:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Use `enabledTools` to register only a subset of tools from an MCP server:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": {
|
||||||
|
"mcpServers": {
|
||||||
|
"filesystem": {
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"],
|
||||||
|
"enabledTools": ["read_file", "mcp_filesystem_write_file"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`enabledTools` accepts either the raw MCP tool name (for example `read_file`) or the wrapped nanobot tool name (for example `mcp_filesystem_write_file`).
|
||||||
|
|
||||||
|
- Omit `enabledTools`, or set it to `["*"]`, to register all tools.
|
||||||
|
- Set `enabledTools` to `[]` to register no tools from that server.
|
||||||
|
- Set `enabledTools` to a non-empty list of names to register only that subset.
|
||||||
|
|
||||||
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
|
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -258,6 +258,9 @@ class AgentLoop:
|
|||||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error consuming inbound message: {}, continuing...", e)
|
||||||
|
continue
|
||||||
|
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/stop":
|
if cmd == "/stop":
|
||||||
|
|||||||
@@ -138,11 +138,47 @@ async def connect_mcp_servers(
|
|||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
tools = await session.list_tools()
|
tools = await session.list_tools()
|
||||||
|
enabled_tools = set(cfg.enabled_tools)
|
||||||
|
allow_all_tools = "*" in enabled_tools
|
||||||
|
registered_count = 0
|
||||||
|
matched_enabled_tools: set[str] = set()
|
||||||
|
available_raw_names = [tool_def.name for tool_def in tools.tools]
|
||||||
|
available_wrapped_names = [f"mcp_{name}_{tool_def.name}" for tool_def in tools.tools]
|
||||||
for tool_def in tools.tools:
|
for tool_def in tools.tools:
|
||||||
|
wrapped_name = f"mcp_{name}_{tool_def.name}"
|
||||||
|
if (
|
||||||
|
not allow_all_tools
|
||||||
|
and tool_def.name not in enabled_tools
|
||||||
|
and wrapped_name not in enabled_tools
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
"MCP: skipping tool '{}' from server '{}' (not in enabledTools)",
|
||||||
|
wrapped_name,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
continue
|
||||||
wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
|
wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
|
||||||
registry.register(wrapper)
|
registry.register(wrapper)
|
||||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||||
|
registered_count += 1
|
||||||
|
if enabled_tools:
|
||||||
|
if tool_def.name in enabled_tools:
|
||||||
|
matched_enabled_tools.add(tool_def.name)
|
||||||
|
if wrapped_name in enabled_tools:
|
||||||
|
matched_enabled_tools.add(wrapped_name)
|
||||||
|
|
||||||
logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools))
|
if enabled_tools and not allow_all_tools:
|
||||||
|
unmatched_enabled_tools = sorted(enabled_tools - matched_enabled_tools)
|
||||||
|
if unmatched_enabled_tools:
|
||||||
|
logger.warning(
|
||||||
|
"MCP server '{}': enabledTools entries not found: {}. Available raw names: {}. "
|
||||||
|
"Available wrapped names: {}",
|
||||||
|
name,
|
||||||
|
", ".join(unmatched_enabled_tools),
|
||||||
|
", ".join(available_raw_names) or "(none)",
|
||||||
|
", ".join(available_wrapped_names) or "(none)",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("MCP server '{}': connected, {} tools registered", name, registered_count)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||||
|
|||||||
@@ -453,6 +453,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
"🐈 nanobot commands:\n"
|
"🐈 nanobot commands:\n"
|
||||||
"/new — Start a new conversation\n"
|
"/new — Start a new conversation\n"
|
||||||
"/stop — Stop the current task\n"
|
"/stop — Stop the current task\n"
|
||||||
|
"/restart — Restart the bot\n"
|
||||||
"/help — Show available commands"
|
"/help — Show available commands"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -448,13 +448,14 @@ def gateway(
|
|||||||
"""Execute a cron job through the agent."""
|
"""Execute a cron job through the agent."""
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.utils.evaluator import evaluate_response
|
||||||
|
|
||||||
reminder_note = (
|
reminder_note = (
|
||||||
"[Scheduled Task] Timer finished.\n\n"
|
"[Scheduled Task] Timer finished.\n\n"
|
||||||
f"Task '{job.name}' has been triggered.\n"
|
f"Task '{job.name}' has been triggered.\n"
|
||||||
f"Scheduled instruction: {job.payload.message}"
|
f"Scheduled instruction: {job.payload.message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prevent the agent from scheduling new cron jobs during execution
|
|
||||||
cron_tool = agent.tools.get("cron")
|
cron_tool = agent.tools.get("cron")
|
||||||
cron_token = None
|
cron_token = None
|
||||||
if isinstance(cron_tool, CronTool):
|
if isinstance(cron_tool, CronTool):
|
||||||
@@ -475,12 +476,16 @@ def gateway(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
if job.payload.deliver and job.payload.to and response:
|
if job.payload.deliver and job.payload.to and response:
|
||||||
from nanobot.bus.events import OutboundMessage
|
should_notify = await evaluate_response(
|
||||||
await bus.publish_outbound(OutboundMessage(
|
response, job.payload.message, provider, agent.model,
|
||||||
channel=job.payload.channel or "cli",
|
)
|
||||||
chat_id=job.payload.to,
|
if should_notify:
|
||||||
content=response
|
from nanobot.bus.events import OutboundMessage
|
||||||
))
|
await bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=job.payload.channel or "cli",
|
||||||
|
chat_id=job.payload.to,
|
||||||
|
content=response,
|
||||||
|
))
|
||||||
return response
|
return response
|
||||||
cron.on_job = on_cron_job
|
cron.on_job = on_cron_job
|
||||||
|
|
||||||
@@ -559,6 +564,10 @@ def gateway(
|
|||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\nShutting down...")
|
console.print("\nShutting down...")
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
|
||||||
|
console.print(traceback.format_exc())
|
||||||
finally:
|
finally:
|
||||||
await agent.close_mcp()
|
await agent.close_mcp()
|
||||||
heartbeat.stop()
|
heartbeat.stop()
|
||||||
@@ -809,7 +818,8 @@ def _get_bridge_dir() -> Path:
|
|||||||
return user_bridge
|
return user_bridge
|
||||||
|
|
||||||
# Check for npm
|
# Check for npm
|
||||||
if not shutil.which("npm"):
|
npm_path = shutil.which("npm")
|
||||||
|
if not npm_path:
|
||||||
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
@@ -839,10 +849,10 @@ def _get_bridge_dir() -> Path:
|
|||||||
# Install and build
|
# Install and build
|
||||||
try:
|
try:
|
||||||
console.print(" Installing dependencies...")
|
console.print(" Installing dependencies...")
|
||||||
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
|
subprocess.run([npm_path, "install"], cwd=user_bridge, check=True, capture_output=True)
|
||||||
|
|
||||||
console.print(" Building...")
|
console.print(" Building...")
|
||||||
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
subprocess.run([npm_path, "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
||||||
|
|
||||||
console.print("[green]✓[/green] Bridge ready\n")
|
console.print("[green]✓[/green] Bridge ready\n")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
@@ -857,6 +867,7 @@ def _get_bridge_dir() -> Path:
|
|||||||
@channels_app.command("login")
|
@channels_app.command("login")
|
||||||
def channels_login():
|
def channels_login():
|
||||||
"""Link device via QR code."""
|
"""Link device via QR code."""
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
from nanobot.config.loader import load_config
|
from nanobot.config.loader import load_config
|
||||||
@@ -875,12 +886,15 @@ def channels_login():
|
|||||||
env["BRIDGE_TOKEN"] = bridge_token
|
env["BRIDGE_TOKEN"] = bridge_token
|
||||||
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
||||||
|
|
||||||
|
npm_path = shutil.which("npm")
|
||||||
|
if not npm_path:
|
||||||
|
console.print("[red]npm not found. Please install Node.js.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
subprocess.run([npm_path, "start"], cwd=bridge_dir, check=True, env=env)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
console.print(f"[red]Bridge failed: {e}[/red]")
|
console.print(f"[red]Bridge failed: {e}[/red]")
|
||||||
except FileNotFoundError:
|
|
||||||
console.print("[red]npm not found. Please install Node.js.[/red]")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class MCPServerConfig(Base):
|
|||||||
url: str = "" # HTTP/SSE: endpoint URL
|
url: str = "" # HTTP/SSE: endpoint URL
|
||||||
headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
|
headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
|
||||||
tool_timeout: int = 30 # seconds before a tool call is cancelled
|
tool_timeout: int = 30 # seconds before a tool call is cancelled
|
||||||
|
enabled_tools: list[str] = Field(default_factory=lambda: ["*"]) # Only register these tools; accepts raw MCP names or wrapped mcp_<server>_<tool> names; ["*"] = all tools; [] = no tools
|
||||||
|
|
||||||
class ToolsConfig(Base):
|
class ToolsConfig(Base):
|
||||||
"""Tools configuration."""
|
"""Tools configuration."""
|
||||||
|
|||||||
@@ -139,6 +139,8 @@ class HeartbeatService:
|
|||||||
|
|
||||||
async def _tick(self) -> None:
|
async def _tick(self) -> None:
|
||||||
"""Execute a single heartbeat tick."""
|
"""Execute a single heartbeat tick."""
|
||||||
|
from nanobot.utils.evaluator import evaluate_response
|
||||||
|
|
||||||
content = self._read_heartbeat_file()
|
content = self._read_heartbeat_file()
|
||||||
if not content:
|
if not content:
|
||||||
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
|
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
|
||||||
@@ -156,9 +158,16 @@ class HeartbeatService:
|
|||||||
logger.info("Heartbeat: tasks found, executing...")
|
logger.info("Heartbeat: tasks found, executing...")
|
||||||
if self.on_execute:
|
if self.on_execute:
|
||||||
response = await self.on_execute(tasks)
|
response = await self.on_execute(tasks)
|
||||||
if response and self.on_notify:
|
|
||||||
logger.info("Heartbeat: completed, delivering response")
|
if response:
|
||||||
await self.on_notify(response)
|
should_notify = await evaluate_response(
|
||||||
|
response, tasks, self.provider, self.model,
|
||||||
|
)
|
||||||
|
if should_notify and self.on_notify:
|
||||||
|
logger.info("Heartbeat: completed, delivering response")
|
||||||
|
await self.on_notify(response)
|
||||||
|
else:
|
||||||
|
logger.info("Heartbeat: silenced by post-run evaluation")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Heartbeat execution failed")
|
logger.exception("Heartbeat execution failed")
|
||||||
|
|
||||||
|
|||||||
92
nanobot/utils/evaluator.py
Normal file
92
nanobot/utils/evaluator.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Post-run evaluation for background tasks (heartbeat & cron).
|
||||||
|
|
||||||
|
After the agent executes a background task, this module makes a lightweight
|
||||||
|
LLM call to decide whether the result warrants notifying the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
_EVALUATE_TOOL = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "evaluate_notification",
|
||||||
|
"description": "Decide whether the user should be notified about this background task result.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"should_notify": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "true = result contains actionable/important info the user should see; false = routine or empty, safe to suppress",
|
||||||
|
},
|
||||||
|
"reason": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "One-sentence reason for the decision",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["should_notify"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a notification gate for a background agent. "
|
||||||
|
"You will be given the original task and the agent's response. "
|
||||||
|
"Call the evaluate_notification tool to decide whether the user "
|
||||||
|
"should be notified.\n\n"
|
||||||
|
"Notify when the response contains actionable information, errors, "
|
||||||
|
"completed deliverables, or anything the user explicitly asked to "
|
||||||
|
"be reminded about.\n\n"
|
||||||
|
"Suppress when the response is a routine status check with nothing "
|
||||||
|
"new, a confirmation that everything is normal, or essentially empty."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def evaluate_response(
|
||||||
|
response: str,
|
||||||
|
task_context: str,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Decide whether a background-task result should be delivered to the user.
|
||||||
|
|
||||||
|
Uses a lightweight tool-call LLM request (same pattern as heartbeat
|
||||||
|
``_decide()``). Falls back to ``True`` (notify) on any failure so
|
||||||
|
that important messages are never silently dropped.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
llm_response = await provider.chat_with_retry(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": (
|
||||||
|
f"## Original task\n{task_context}\n\n"
|
||||||
|
f"## Agent response\n{response}"
|
||||||
|
)},
|
||||||
|
],
|
||||||
|
tools=_EVALUATE_TOOL,
|
||||||
|
model=model,
|
||||||
|
max_tokens=256,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not llm_response.has_tool_calls:
|
||||||
|
logger.warning("evaluate_response: no tool call returned, defaulting to notify")
|
||||||
|
return True
|
||||||
|
|
||||||
|
args = llm_response.tool_calls[0].arguments
|
||||||
|
should_notify = args.get("should_notify", True)
|
||||||
|
reason = args.get("reason", "")
|
||||||
|
logger.info("evaluate_response: should_notify={}, reason={}", should_notify, reason)
|
||||||
|
return bool(should_notify)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("evaluate_response failed, defaulting to notify")
|
||||||
|
return True
|
||||||
63
tests/test_evaluator.py
Normal file
63
tests/test_evaluator.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.utils.evaluator import evaluate_response
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProvider(LLMProvider):
|
||||||
|
def __init__(self, responses: list[LLMResponse]):
|
||||||
|
super().__init__()
|
||||||
|
self._responses = list(responses)
|
||||||
|
|
||||||
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
if self._responses:
|
||||||
|
return self._responses.pop(0)
|
||||||
|
return LLMResponse(content="", tool_calls=[])
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "test-model"
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_tool_call(should_notify: bool, reason: str = "") -> LLMResponse:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="eval_1",
|
||||||
|
name="evaluate_notification",
|
||||||
|
arguments={"should_notify": should_notify, "reason": reason},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_should_notify_true() -> None:
|
||||||
|
provider = DummyProvider([_eval_tool_call(True, "user asked to be reminded")])
|
||||||
|
result = await evaluate_response("Task completed with results", "check emails", provider, "m")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_should_notify_false() -> None:
|
||||||
|
provider = DummyProvider([_eval_tool_call(False, "routine check, nothing new")])
|
||||||
|
result = await evaluate_response("All clear, no updates", "check status", provider, "m")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_on_error() -> None:
|
||||||
|
class FailingProvider(DummyProvider):
|
||||||
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
raise RuntimeError("provider down")
|
||||||
|
|
||||||
|
provider = FailingProvider([])
|
||||||
|
result = await evaluate_response("some response", "some task", provider, "m")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_call_fallback() -> None:
|
||||||
|
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
|
||||||
|
result = await evaluate_response("some response", "some task", provider, "m")
|
||||||
|
assert result is True
|
||||||
@@ -123,6 +123,98 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
|||||||
assert await service.trigger_now() is None
|
assert await service.trigger_now() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tick_notifies_when_evaluator_says_yes(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=notify -> on_notify called."""
|
||||||
|
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check deployments", encoding="utf-8")
|
||||||
|
|
||||||
|
provider = DummyProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "run", "tasks": "check deployments"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
executed: list[str] = []
|
||||||
|
notified: list[str] = []
|
||||||
|
|
||||||
|
async def _on_execute(tasks: str) -> str:
|
||||||
|
executed.append(tasks)
|
||||||
|
return "deployment failed on staging"
|
||||||
|
|
||||||
|
async def _on_notify(response: str) -> None:
|
||||||
|
notified.append(response)
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
on_execute=_on_execute,
|
||||||
|
on_notify=_on_notify,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _eval_notify(*a, **kw):
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_notify)
|
||||||
|
|
||||||
|
await service._tick()
|
||||||
|
assert executed == ["check deployments"]
|
||||||
|
assert notified == ["deployment failed on staging"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=silent -> on_notify NOT called."""
|
||||||
|
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check status", encoding="utf-8")
|
||||||
|
|
||||||
|
provider = DummyProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "run", "tasks": "check status"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
executed: list[str] = []
|
||||||
|
notified: list[str] = []
|
||||||
|
|
||||||
|
async def _on_execute(tasks: str) -> str:
|
||||||
|
executed.append(tasks)
|
||||||
|
return "everything is fine, no issues"
|
||||||
|
|
||||||
|
async def _on_notify(response: str) -> None:
|
||||||
|
notified.append(response)
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
on_execute=_on_execute,
|
||||||
|
on_notify=_on_notify,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _eval_silent(*a, **kw):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_silent)
|
||||||
|
|
||||||
|
await service._tick()
|
||||||
|
assert executed == ["check status"]
|
||||||
|
assert notified == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||||
provider = DummyProvider([
|
provider = DummyProvider([
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType, SimpleNamespace
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools.mcp import MCPToolWrapper
|
from nanobot.agent.tools.mcp import MCPToolWrapper, connect_mcp_servers
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.config.schema import MCPServerConfig
|
||||||
|
|
||||||
|
|
||||||
class _FakeTextContent:
|
class _FakeTextContent:
|
||||||
@@ -14,12 +17,63 @@ class _FakeTextContent:
|
|||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_mcp_runtime() -> dict[str, object | None]:
|
||||||
|
return {"session": None}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
|
def _fake_mcp_module(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, fake_mcp_runtime: dict[str, object | None]
|
||||||
|
) -> None:
|
||||||
mod = ModuleType("mcp")
|
mod = ModuleType("mcp")
|
||||||
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
||||||
|
|
||||||
|
class _FakeStdioServerParameters:
|
||||||
|
def __init__(self, command: str, args: list[str], env: dict | None = None) -> None:
|
||||||
|
self.command = command
|
||||||
|
self.args = args
|
||||||
|
self.env = env
|
||||||
|
|
||||||
|
class _FakeClientSession:
|
||||||
|
def __init__(self, _read: object, _write: object) -> None:
|
||||||
|
self._session = fake_mcp_runtime["session"]
|
||||||
|
|
||||||
|
async def __aenter__(self) -> object:
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_stdio_client(_params: object):
|
||||||
|
yield object(), object()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_sse_client(_url: str, httpx_client_factory=None):
|
||||||
|
yield object(), object()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_streamable_http_client(_url: str, http_client=None):
|
||||||
|
yield object(), object(), object()
|
||||||
|
|
||||||
|
mod.ClientSession = _FakeClientSession
|
||||||
|
mod.StdioServerParameters = _FakeStdioServerParameters
|
||||||
monkeypatch.setitem(sys.modules, "mcp", mod)
|
monkeypatch.setitem(sys.modules, "mcp", mod)
|
||||||
|
|
||||||
|
client_mod = ModuleType("mcp.client")
|
||||||
|
stdio_mod = ModuleType("mcp.client.stdio")
|
||||||
|
stdio_mod.stdio_client = _fake_stdio_client
|
||||||
|
sse_mod = ModuleType("mcp.client.sse")
|
||||||
|
sse_mod.sse_client = _fake_sse_client
|
||||||
|
streamable_http_mod = ModuleType("mcp.client.streamable_http")
|
||||||
|
streamable_http_mod.streamable_http_client = _fake_streamable_http_client
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client", client_mod)
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client.stdio", stdio_mod)
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client.sse", sse_mod)
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", streamable_http_mod)
|
||||||
|
|
||||||
|
|
||||||
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
||||||
tool_def = SimpleNamespace(
|
tool_def = SimpleNamespace(
|
||||||
@@ -97,3 +151,132 @@ async def test_execute_handles_generic_exception() -> None:
|
|||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
|
|
||||||
assert result == "(MCP tool call failed: RuntimeError)"
|
assert result == "(MCP tool call failed: RuntimeError)"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_def(name: str) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
name=name,
|
||||||
|
description=f"{name} tool",
|
||||||
|
inputSchema={"type": "object", "properties": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_session(tool_names: list[str]) -> SimpleNamespace:
|
||||||
|
async def initialize() -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_tools() -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(tools=[_make_tool_def(name) for name in tool_names])
|
||||||
|
|
||||||
|
return SimpleNamespace(initialize=initialize, list_tools=list_tools)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
||||||
|
fake_mcp_runtime: dict[str, object | None],
|
||||||
|
) -> None:
|
||||||
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
try:
|
||||||
|
await connect_mcp_servers(
|
||||||
|
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||||
|
registry,
|
||||||
|
stack,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
||||||
|
fake_mcp_runtime: dict[str, object | None],
|
||||||
|
) -> None:
|
||||||
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
try:
|
||||||
|
await connect_mcp_servers(
|
||||||
|
{"test": MCPServerConfig(command="fake")},
|
||||||
|
registry,
|
||||||
|
stack,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
||||||
|
fake_mcp_runtime: dict[str, object | None],
|
||||||
|
) -> None:
|
||||||
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
try:
|
||||||
|
await connect_mcp_servers(
|
||||||
|
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||||
|
registry,
|
||||||
|
stack,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
||||||
|
fake_mcp_runtime: dict[str, object | None],
|
||||||
|
) -> None:
|
||||||
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
try:
|
||||||
|
await connect_mcp_servers(
|
||||||
|
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||||
|
registry,
|
||||||
|
stack,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||||
|
fake_mcp_runtime: dict[str, object | None], monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
fake_mcp_runtime["session"] = _make_fake_session(["demo"])
|
||||||
|
registry = ToolRegistry()
|
||||||
|
warnings: list[str] = []
|
||||||
|
|
||||||
|
def _warning(message: str, *args: object) -> None:
|
||||||
|
warnings.append(message.format(*args))
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||||
|
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
try:
|
||||||
|
await connect_mcp_servers(
|
||||||
|
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||||
|
registry,
|
||||||
|
stack,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == []
|
||||||
|
assert warnings
|
||||||
|
assert "enabledTools entries not found: unknown" in warnings[-1]
|
||||||
|
assert "Available raw names: demo" in warnings[-1]
|
||||||
|
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||||
|
|||||||
@@ -647,3 +647,19 @@ async def test_forward_command_does_not_inject_reply_context() -> None:
|
|||||||
|
|
||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
assert handled[0]["content"] == "/new"
|
assert handled[0]["content"] == "/new"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_help_includes_restart_command() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
update = _make_telegram_update(text="/help", chat_type="private")
|
||||||
|
update.message.reply_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel._on_help(update, None)
|
||||||
|
|
||||||
|
update.message.reply_text.assert_awaited_once()
|
||||||
|
help_text = update.message.reply_text.await_args.args[0]
|
||||||
|
assert "/restart" in help_text
|
||||||
|
|||||||
Reference in New Issue
Block a user