refactor(agent): unify process_direct to return OutboundMessage
Merge process_direct() and process_direct_outbound() into a single interface returning OutboundMessage | None. This eliminates the dual-path detection logic in CLI single-message mode that relied on inspect.iscoroutinefunction to distinguish between the two APIs. Extract status rendering into a pure function build_status_content() in utils/helpers.py, decoupling it from AgentLoop internals. Made-with: Cursor
This commit is contained in:
@@ -27,6 +27,7 @@ from nanobot.agent.tools.shell import ExecTool
|
|||||||
from nanobot.agent.tools.spawn import SpawnTool
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
from nanobot.utils.helpers import build_status_content
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
@@ -185,48 +186,25 @@ class AgentLoop:
|
|||||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||||
|
|
||||||
def _build_status_content(self, session: Session) -> str:
|
|
||||||
"""Build a human-readable runtime status snapshot."""
|
|
||||||
history = session.get_history(max_messages=0)
|
|
||||||
msg_count = len(history)
|
|
||||||
|
|
||||||
uptime_s = int(time.time() - self._start_time)
|
|
||||||
uptime = (
|
|
||||||
f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
|
|
||||||
if uptime_s >= 3600
|
|
||||||
else f"{uptime_s // 60}m {uptime_s % 60}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
last_in = self._last_usage.get("prompt_tokens", 0)
|
|
||||||
last_out = self._last_usage.get("completion_tokens", 0)
|
|
||||||
|
|
||||||
ctx_used = 0
|
|
||||||
try:
|
|
||||||
ctx_used, _ = self.memory_consolidator.estimate_session_prompt_tokens(session)
|
|
||||||
except Exception:
|
|
||||||
ctx_used = 0
|
|
||||||
if ctx_used <= 0:
|
|
||||||
ctx_used = last_in
|
|
||||||
ctx_total_tokens = max(self.context_window_tokens, 0)
|
|
||||||
ctx_pct = int((ctx_used / ctx_total_tokens) * 100) if ctx_total_tokens > 0 else 0
|
|
||||||
ctx_used_str = f"{ctx_used // 1000}k" if ctx_used >= 1000 else str(ctx_used)
|
|
||||||
ctx_total_str = f"{ctx_total_tokens // 1024}k" if ctx_total_tokens > 0 else "n/a"
|
|
||||||
|
|
||||||
return "\n".join([
|
|
||||||
f"🐈 nanobot v{__version__}",
|
|
||||||
f"🧠 Model: {self.model}",
|
|
||||||
f"📊 Tokens: {last_in} in / {last_out} out",
|
|
||||||
f"📚 Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
|
||||||
f"💬 Session: {msg_count} messages",
|
|
||||||
f"⏱ Uptime: {uptime}",
|
|
||||||
])
|
|
||||||
|
|
||||||
def _status_response(self, msg: InboundMessage, session: Session) -> OutboundMessage:
|
def _status_response(self, msg: InboundMessage, session: Session) -> OutboundMessage:
|
||||||
"""Build an outbound status message for a session."""
|
"""Build an outbound status message for a session."""
|
||||||
|
ctx_est = 0
|
||||||
|
try:
|
||||||
|
ctx_est, _ = self.memory_consolidator.estimate_session_prompt_tokens(session)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if ctx_est <= 0:
|
||||||
|
ctx_est = self._last_usage.get("prompt_tokens", 0)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=self._build_status_content(session),
|
content=build_status_content(
|
||||||
|
version=__version__, model=self.model,
|
||||||
|
start_time=self._start_time, last_usage=self._last_usage,
|
||||||
|
context_window_tokens=self.context_window_tokens,
|
||||||
|
session_msg_count=len(session.get_history(max_messages=0)),
|
||||||
|
context_tokens_estimate=ctx_est,
|
||||||
|
),
|
||||||
metadata={"render_as": "text"},
|
metadata={"render_as": "text"},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -607,7 +585,7 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
async def process_direct_outbound(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
session_key: str = "cli:direct",
|
session_key: str = "cli:direct",
|
||||||
@@ -619,21 +597,3 @@ class AgentLoop:
|
|||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||||
|
|
||||||
async def process_direct(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
session_key: str = "cli:direct",
|
|
||||||
channel: str = "cli",
|
|
||||||
chat_id: str = "direct",
|
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Process a message directly (for CLI or cron usage)."""
|
|
||||||
response = await self.process_direct_outbound(
|
|
||||||
content,
|
|
||||||
session_key=session_key,
|
|
||||||
channel=channel,
|
|
||||||
chat_id=chat_id,
|
|
||||||
on_progress=on_progress,
|
|
||||||
)
|
|
||||||
return response.content if response else ""
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import signal
|
import signal
|
||||||
@@ -579,7 +579,7 @@ def gateway(
|
|||||||
if isinstance(cron_tool, CronTool):
|
if isinstance(cron_tool, CronTool):
|
||||||
cron_token = cron_tool.set_cron_context(True)
|
cron_token = cron_tool.set_cron_context(True)
|
||||||
try:
|
try:
|
||||||
response = await agent.process_direct(
|
resp = await agent.process_direct(
|
||||||
reminder_note,
|
reminder_note,
|
||||||
session_key=f"cron:{job.id}",
|
session_key=f"cron:{job.id}",
|
||||||
channel=job.payload.channel or "cli",
|
channel=job.payload.channel or "cli",
|
||||||
@@ -589,6 +589,8 @@ def gateway(
|
|||||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||||
cron_tool.reset_cron_context(cron_token)
|
cron_tool.reset_cron_context(cron_token)
|
||||||
|
|
||||||
|
response = resp.content if resp else ""
|
||||||
|
|
||||||
message_tool = agent.tools.get("message")
|
message_tool = agent.tools.get("message")
|
||||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||||
return response
|
return response
|
||||||
@@ -634,13 +636,14 @@ def gateway(
|
|||||||
async def _silent(*_args, **_kwargs):
|
async def _silent(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return await agent.process_direct(
|
resp = await agent.process_direct(
|
||||||
tasks,
|
tasks,
|
||||||
session_key="heartbeat",
|
session_key="heartbeat",
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
on_progress=_silent,
|
on_progress=_silent,
|
||||||
)
|
)
|
||||||
|
return resp.content if resp else ""
|
||||||
|
|
||||||
async def on_heartbeat_notify(response: str) -> None:
|
async def on_heartbeat_notify(response: str) -> None:
|
||||||
"""Deliver a heartbeat response to the user's channel."""
|
"""Deliver a heartbeat response to the user's channel."""
|
||||||
@@ -768,27 +771,15 @@ def agent(
|
|||||||
nonlocal _thinking
|
nonlocal _thinking
|
||||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
with _thinking:
|
with _thinking:
|
||||||
direct_outbound = getattr(agent_loop, "process_direct_outbound", None)
|
response = await agent_loop.process_direct(
|
||||||
if inspect.iscoroutinefunction(direct_outbound):
|
message, session_id, on_progress=_cli_progress,
|
||||||
response = await agent_loop.process_direct_outbound(
|
|
||||||
message,
|
|
||||||
session_id,
|
|
||||||
on_progress=_cli_progress,
|
|
||||||
)
|
)
|
||||||
response_content = response.content if response else ""
|
|
||||||
response_meta = response.metadata if response else None
|
|
||||||
else:
|
|
||||||
response_content = await agent_loop.process_direct(
|
|
||||||
message,
|
|
||||||
session_id,
|
|
||||||
on_progress=_cli_progress,
|
|
||||||
)
|
|
||||||
response_meta = None
|
|
||||||
_thinking = None
|
_thinking = None
|
||||||
kwargs = {"render_markdown": markdown}
|
_print_agent_response(
|
||||||
if response_meta is not None:
|
response.content if response else "",
|
||||||
kwargs["metadata"] = response_meta
|
render_markdown=markdown,
|
||||||
_print_agent_response(response_content, **kwargs)
|
metadata=response.metadata if response else None,
|
||||||
|
)
|
||||||
await agent_loop.close_mcp()
|
await agent_loop.close_mcp()
|
||||||
|
|
||||||
asyncio.run(run_once())
|
asyncio.run(run_once())
|
||||||
|
|||||||
@@ -192,6 +192,39 @@ def estimate_prompt_tokens_chain(
|
|||||||
return 0, "none"
|
return 0, "none"
|
||||||
|
|
||||||
|
|
||||||
|
def build_status_content(
|
||||||
|
*,
|
||||||
|
version: str,
|
||||||
|
model: str,
|
||||||
|
start_time: float,
|
||||||
|
last_usage: dict[str, int],
|
||||||
|
context_window_tokens: int,
|
||||||
|
session_msg_count: int,
|
||||||
|
context_tokens_estimate: int,
|
||||||
|
) -> str:
|
||||||
|
"""Build a human-readable runtime status snapshot."""
|
||||||
|
uptime_s = int(time.time() - start_time)
|
||||||
|
uptime = (
|
||||||
|
f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
|
||||||
|
if uptime_s >= 3600
|
||||||
|
else f"{uptime_s // 60}m {uptime_s % 60}s"
|
||||||
|
)
|
||||||
|
last_in = last_usage.get("prompt_tokens", 0)
|
||||||
|
last_out = last_usage.get("completion_tokens", 0)
|
||||||
|
ctx_total = max(context_window_tokens, 0)
|
||||||
|
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||||
|
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||||
|
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||||
|
return "\n".join([
|
||||||
|
f"\U0001f408 nanobot v{version}",
|
||||||
|
f"\U0001f9e0 Model: {model}",
|
||||||
|
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
|
||||||
|
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||||
|
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||||
|
f"\u23f1 Uptime: {uptime}",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||||
from importlib.resources import files as pkg_files
|
from importlib.resources import files as pkg_files
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.cli.commands import _make_provider, app
|
from nanobot.cli.commands import _make_provider, app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
@@ -345,7 +346,9 @@ def mock_agent_runtime(tmp_path):
|
|||||||
|
|
||||||
agent_loop = MagicMock()
|
agent_loop = MagicMock()
|
||||||
agent_loop.channels_config = None
|
agent_loop.channels_config = None
|
||||||
agent_loop.process_direct = AsyncMock(return_value="mock-response")
|
agent_loop.process_direct = AsyncMock(
|
||||||
|
return_value=OutboundMessage(channel="cli", chat_id="direct", content="mock-response"),
|
||||||
|
)
|
||||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||||
mock_agent_loop_cls.return_value = agent_loop
|
mock_agent_loop_cls.return_value = agent_loop
|
||||||
|
|
||||||
@@ -382,7 +385,9 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_
|
|||||||
mock_agent_runtime["config"].workspace_path
|
mock_agent_runtime["config"].workspace_path
|
||||||
)
|
)
|
||||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||||
mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
|
mock_agent_runtime["print_response"].assert_called_once_with(
|
||||||
|
"mock-response", render_markdown=True, metadata={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
||||||
@@ -418,8 +423,8 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs) -> str:
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return "ok"
|
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -175,14 +175,14 @@ class TestRestartCommand:
|
|||||||
assert "Context: 1k/64k (1%)" in response.content
|
assert "Context: 1k/64k (1%)" in response.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_direct_outbound_preserves_render_metadata(self):
|
async def test_process_direct_preserves_render_metadata(self):
|
||||||
loop, _bus = _make_loop()
|
loop, _bus = _make_loop()
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.get_history.return_value = []
|
session.get_history.return_value = []
|
||||||
loop.sessions.get_or_create.return_value = session
|
loop.sessions.get_or_create.return_value = session
|
||||||
loop.subagents.get_running_count.return_value = 0
|
loop.subagents.get_running_count.return_value = 0
|
||||||
|
|
||||||
response = await loop.process_direct_outbound("/status", session_key="cli:test")
|
response = await loop.process_direct("/status", session_key="cli:test")
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.metadata == {"render_as": "text"}
|
assert response.metadata == {"render_as": "text"}
|
||||||
|
|||||||
Reference in New Issue
Block a user