refactor(agent): unify process_direct to return OutboundMessage
Merge process_direct() and process_direct_outbound() into a single interface returning OutboundMessage | None. This eliminates the dual-path detection logic in CLI single-message mode that relied on inspect.iscoroutinefunction to distinguish between the two APIs. Extract status rendering into a pure function build_status_content() in utils/helpers.py, decoupling it from AgentLoop internals. Made-with: Cursor
This commit is contained in:
@@ -27,6 +27,7 @@ from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.utils.helpers import build_status_content
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
@@ -185,48 +186,25 @@ class AgentLoop:
|
||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
def _build_status_content(self, session: Session) -> str:
|
||||
"""Build a human-readable runtime status snapshot."""
|
||||
history = session.get_history(max_messages=0)
|
||||
msg_count = len(history)
|
||||
|
||||
uptime_s = int(time.time() - self._start_time)
|
||||
uptime = (
|
||||
f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
|
||||
if uptime_s >= 3600
|
||||
else f"{uptime_s // 60}m {uptime_s % 60}s"
|
||||
)
|
||||
|
||||
last_in = self._last_usage.get("prompt_tokens", 0)
|
||||
last_out = self._last_usage.get("completion_tokens", 0)
|
||||
|
||||
ctx_used = 0
|
||||
try:
|
||||
ctx_used, _ = self.memory_consolidator.estimate_session_prompt_tokens(session)
|
||||
except Exception:
|
||||
ctx_used = 0
|
||||
if ctx_used <= 0:
|
||||
ctx_used = last_in
|
||||
ctx_total_tokens = max(self.context_window_tokens, 0)
|
||||
ctx_pct = int((ctx_used / ctx_total_tokens) * 100) if ctx_total_tokens > 0 else 0
|
||||
ctx_used_str = f"{ctx_used // 1000}k" if ctx_used >= 1000 else str(ctx_used)
|
||||
ctx_total_str = f"{ctx_total_tokens // 1024}k" if ctx_total_tokens > 0 else "n/a"
|
||||
|
||||
return "\n".join([
|
||||
f"🐈 nanobot v{__version__}",
|
||||
f"🧠 Model: {self.model}",
|
||||
f"📊 Tokens: {last_in} in / {last_out} out",
|
||||
f"📚 Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"💬 Session: {msg_count} messages",
|
||||
f"⏱ Uptime: {uptime}",
|
||||
])
|
||||
|
||||
def _status_response(self, msg: InboundMessage, session: Session) -> OutboundMessage:
|
||||
"""Build an outbound status message for a session."""
|
||||
ctx_est = 0
|
||||
try:
|
||||
ctx_est, _ = self.memory_consolidator.estimate_session_prompt_tokens(session)
|
||||
except Exception:
|
||||
pass
|
||||
if ctx_est <= 0:
|
||||
ctx_est = self._last_usage.get("prompt_tokens", 0)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=self._build_status_content(session),
|
||||
content=build_status_content(
|
||||
version=__version__, model=self.model,
|
||||
start_time=self._start_time, last_usage=self._last_usage,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
session_msg_count=len(session.get_history(max_messages=0)),
|
||||
context_tokens_estimate=ctx_est,
|
||||
),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
|
||||
@@ -607,7 +585,7 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def process_direct_outbound(
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
session_key: str = "cli:direct",
|
||||
@@ -619,21 +597,3 @@ class AgentLoop:
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> str:
|
||||
"""Process a message directly (for CLI or cron usage)."""
|
||||
response = await self.process_direct_outbound(
|
||||
content,
|
||||
session_key=session_key,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
return response.content if response else ""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import inspect
|
||||
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
@@ -579,7 +579,7 @@ def gateway(
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_token = cron_tool.set_cron_context(True)
|
||||
try:
|
||||
response = await agent.process_direct(
|
||||
resp = await agent.process_direct(
|
||||
reminder_note,
|
||||
session_key=f"cron:{job.id}",
|
||||
channel=job.payload.channel or "cli",
|
||||
@@ -589,6 +589,8 @@ def gateway(
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
|
||||
response = resp.content if resp else ""
|
||||
|
||||
message_tool = agent.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
return response
|
||||
@@ -634,13 +636,14 @@ def gateway(
|
||||
async def _silent(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
return await agent.process_direct(
|
||||
resp = await agent.process_direct(
|
||||
tasks,
|
||||
session_key="heartbeat",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
)
|
||||
return resp.content if resp else ""
|
||||
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
@@ -768,27 +771,15 @@ def agent(
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
with _thinking:
|
||||
direct_outbound = getattr(agent_loop, "process_direct_outbound", None)
|
||||
if inspect.iscoroutinefunction(direct_outbound):
|
||||
response = await agent_loop.process_direct_outbound(
|
||||
message,
|
||||
session_id,
|
||||
on_progress=_cli_progress,
|
||||
)
|
||||
response_content = response.content if response else ""
|
||||
response_meta = response.metadata if response else None
|
||||
else:
|
||||
response_content = await agent_loop.process_direct(
|
||||
message,
|
||||
session_id,
|
||||
on_progress=_cli_progress,
|
||||
)
|
||||
response_meta = None
|
||||
response = await agent_loop.process_direct(
|
||||
message, session_id, on_progress=_cli_progress,
|
||||
)
|
||||
_thinking = None
|
||||
kwargs = {"render_markdown": markdown}
|
||||
if response_meta is not None:
|
||||
kwargs["metadata"] = response_meta
|
||||
_print_agent_response(response_content, **kwargs)
|
||||
_print_agent_response(
|
||||
response.content if response else "",
|
||||
render_markdown=markdown,
|
||||
metadata=response.metadata if response else None,
|
||||
)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
|
||||
@@ -192,6 +192,39 @@ def estimate_prompt_tokens_chain(
|
||||
return 0, "none"
|
||||
|
||||
|
||||
def build_status_content(
|
||||
*,
|
||||
version: str,
|
||||
model: str,
|
||||
start_time: float,
|
||||
last_usage: dict[str, int],
|
||||
context_window_tokens: int,
|
||||
session_msg_count: int,
|
||||
context_tokens_estimate: int,
|
||||
) -> str:
|
||||
"""Build a human-readable runtime status snapshot."""
|
||||
uptime_s = int(time.time() - start_time)
|
||||
uptime = (
|
||||
f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
|
||||
if uptime_s >= 3600
|
||||
else f"{uptime_s // 60}m {uptime_s % 60}s"
|
||||
)
|
||||
last_in = last_usage.get("prompt_tokens", 0)
|
||||
last_out = last_usage.get("completion_tokens", 0)
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||
return "\n".join([
|
||||
f"\U0001f408 nanobot v{version}",
|
||||
f"\U0001f9e0 Model: {model}",
|
||||
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||
f"\u23f1 Uptime: {uptime}",
|
||||
])
|
||||
|
||||
|
||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||
from importlib.resources import files as pkg_files
|
||||
|
||||
Reference in New Issue
Block a user