Merge branch 'main' into pr-1900
This commit is contained in:
@@ -64,7 +64,7 @@
|
|||||||
|
|
||||||
## Key Features of nanobot:
|
## Key Features of nanobot:
|
||||||
|
|
||||||
🪶 **Ultra-Lightweight**: Just ~4,000 lines of core agent code — 99% smaller than Clawdbot.
|
🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
|
||||||
|
|
||||||
🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
|
🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
|
||||||
|
|
||||||
@@ -502,7 +502,8 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
"appSecret": "xxx",
|
"appSecret": "xxx",
|
||||||
"encryptKey": "",
|
"encryptKey": "",
|
||||||
"verificationToken": "",
|
"verificationToken": "",
|
||||||
"allowFrom": ["ou_YOUR_OPEN_ID"]
|
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
||||||
|
"groupPolicy": "mention"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -510,6 +511,7 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
|
|
||||||
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
||||||
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
||||||
|
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
|
||||||
|
|
||||||
**3. Run**
|
**3. Run**
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
@@ -43,7 +45,7 @@ class AgentLoop:
|
|||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_RESULT_MAX_CHARS = 500
|
_TOOL_RESULT_MAX_CHARS = 16_000
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -256,8 +258,11 @@ class AgentLoop:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if msg.content.strip().lower() == "/stop":
|
cmd = msg.content.strip().lower()
|
||||||
|
if cmd == "/stop":
|
||||||
await self._handle_stop(msg)
|
await self._handle_stop(msg)
|
||||||
|
elif cmd == "/restart":
|
||||||
|
await self._handle_restart(msg)
|
||||||
else:
|
else:
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||||
@@ -274,11 +279,23 @@ class AgentLoop:
|
|||||||
pass
|
pass
|
||||||
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||||
total = cancelled + sub_cancelled
|
total = cancelled + sub_cancelled
|
||||||
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
async def _handle_restart(self, msg: InboundMessage) -> None:
|
||||||
|
"""Restart the process in-place via os.execv."""
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _do_restart():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
os.execv(sys.executable, [sys.executable] + sys.argv)
|
||||||
|
|
||||||
|
asyncio.create_task(_do_restart())
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message under the global lock."""
|
"""Process a message under the global lock."""
|
||||||
async with self._processing_lock:
|
async with self._processing_lock:
|
||||||
@@ -373,9 +390,16 @@ class AgentLoop:
|
|||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="New session started.")
|
content="New session started.")
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
lines = [
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
"🐈 nanobot commands:",
|
||||||
|
"/new — Start a new conversation",
|
||||||
|
"/stop — Stop the current task",
|
||||||
|
"/restart — Restart the bot",
|
||||||
|
"/help — Show available commands",
|
||||||
|
]
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||||
|
)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class MemoryStore:
|
|||||||
],
|
],
|
||||||
tools=_SAVE_MEMORY_TOOL,
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
model=model,
|
model=model,
|
||||||
|
tool_choice="required",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
|
|||||||
@@ -352,6 +352,27 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Feishu bot stopped")
|
logger.info("Feishu bot stopped")
|
||||||
|
|
||||||
|
def _is_bot_mentioned(self, message: Any) -> bool:
|
||||||
|
"""Check if the bot is @mentioned in the message."""
|
||||||
|
raw_content = message.content or ""
|
||||||
|
if "@_all" in raw_content:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for mention in getattr(message, "mentions", None) or []:
|
||||||
|
mid = getattr(mention, "id", None)
|
||||||
|
if not mid:
|
||||||
|
continue
|
||||||
|
# Bot mentions have no user_id (None or "") but a valid open_id
|
||||||
|
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_group_message_for_bot(self, message: Any) -> bool:
|
||||||
|
"""Allow group messages when policy is open or bot is @mentioned."""
|
||||||
|
if self.config.group_policy == "open":
|
||||||
|
return True
|
||||||
|
return self._is_bot_mentioned(message)
|
||||||
|
|
||||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||||
@@ -893,6 +914,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
chat_type = message.chat_type
|
chat_type = message.chat_type
|
||||||
msg_type = message.message_type
|
msg_type = message.message_type
|
||||||
|
|
||||||
|
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
||||||
|
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||||
|
return
|
||||||
|
|
||||||
# Add reaction
|
# Add reaction
|
||||||
await self._add_reaction(message_id, self.config.react_emoji)
|
await self._add_reaction(message_id, self.config.react_emoji)
|
||||||
|
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
BotCommand("new", "Start a new conversation"),
|
BotCommand("new", "Start a new conversation"),
|
||||||
BotCommand("stop", "Stop the current task"),
|
BotCommand("stop", "Stop the current task"),
|
||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
|
BotCommand("restart", "Restart the bot"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, config: TelegramConfig, bus: MessageBus):
|
def __init__(self, config: TelegramConfig, bus: MessageBus):
|
||||||
@@ -221,6 +222,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||||
|
self._app.add_handler(CommandHandler("restart", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class FeishuConfig(Base):
|
|||||||
react_emoji: str = (
|
react_emoji: str = (
|
||||||
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
||||||
)
|
)
|
||||||
|
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned, "open" responds to all
|
||||||
|
|
||||||
|
|
||||||
class DingTalkConfig(Base):
|
class DingTalkConfig(Base):
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
@@ -106,7 +107,7 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
payload["tools"] = tools
|
payload["tools"] = tools
|
||||||
payload["tool_choice"] = "auto"
|
payload["tool_choice"] = tool_choice or "auto"
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@@ -118,6 +119,7 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request to Azure OpenAI.
|
Send a chat completion request to Azure OpenAI.
|
||||||
@@ -137,7 +139,8 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
url = self._build_chat_url(deployment_name)
|
url = self._build_chat_url(deployment_name)
|
||||||
headers = self._build_headers()
|
headers = self._build_headers()
|
||||||
payload = self._prepare_request_payload(
|
payload = self._prepare_request_payload(
|
||||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
|
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ class LLMProvider(ABC):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request.
|
Send a chat completion request.
|
||||||
@@ -176,6 +177,7 @@ class LLMProvider(ABC):
|
|||||||
model: Model identifier (provider-specific).
|
model: Model identifier (provider-specific).
|
||||||
max_tokens: Maximum tokens in response.
|
max_tokens: Maximum tokens in response.
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
|
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMResponse with content and/or tool calls.
|
LLMResponse with content and/or tool calls.
|
||||||
@@ -195,6 +197,7 @@ class LLMProvider(ABC):
|
|||||||
max_tokens: object = _SENTINEL,
|
max_tokens: object = _SENTINEL,
|
||||||
temperature: object = _SENTINEL,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures.
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
@@ -218,6 +221,7 @@ class LLMProvider(ABC):
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort,
|
reasoning_effort=reasoning_effort,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
@@ -250,6 +254,7 @@ class LLMProvider(ABC):
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort,
|
reasoning_effort=reasoning_effort,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class CustomProvider(LLMProvider):
|
|||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None) -> LLMResponse:
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"messages": self._sanitize_empty_content(messages),
|
"messages": self._sanitize_empty_content(messages),
|
||||||
@@ -35,7 +36,7 @@ class CustomProvider(LLMProvider):
|
|||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice="auto")
|
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
||||||
try:
|
try:
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request via LiteLLM.
|
Send a chat completion request via LiteLLM.
|
||||||
@@ -267,7 +268,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
kwargs["tool_choice"] = "auto"
|
kwargs["tool_choice"] = tool_choice or "auto"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(**kwargs)
|
response = await acompletion(**kwargs)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
system_prompt, input_items = _convert_messages(messages)
|
system_prompt, input_items = _convert_messages(messages)
|
||||||
@@ -48,7 +49,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
"text": {"verbosity": "medium"},
|
"text": {"verbosity": "medium"},
|
||||||
"include": ["reasoning.encrypted_content"],
|
"include": ["reasoning.encrypted_content"],
|
||||||
"prompt_cache_key": _prompt_cache_key(messages),
|
"prompt_cache_key": _prompt_cache_key(messages),
|
||||||
"tool_choice": "auto",
|
"tool_choice": tool_choice or "auto",
|
||||||
"parallel_tool_calls": True,
|
"parallel_tool_calls": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
wecom = [
|
wecom = [
|
||||||
"wecom-aibot-sdk-python @ git+https://github.com/chengyongru/wecom_aibot_sdk.git@v0.1.2",
|
"wecom-aibot-sdk-python>=0.1.2",
|
||||||
]
|
]
|
||||||
matrix = [
|
matrix = [
|
||||||
"matrix-nio[e2e]>=0.25.2",
|
"matrix-nio[e2e]>=0.25.2",
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
|
|||||||
|
|
||||||
def _mk_loop() -> AgentLoop:
|
def _mk_loop() -> AgentLoop:
|
||||||
loop = AgentLoop.__new__(AgentLoop)
|
loop = AgentLoop.__new__(AgentLoop)
|
||||||
loop._TOOL_RESULT_MAX_CHARS = 500
|
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||||
return loop
|
return loop
|
||||||
|
|
||||||
|
|
||||||
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
|
|||||||
skip=0,
|
skip=0,
|
||||||
)
|
)
|
||||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_turn_keeps_tool_results_under_16k() -> None:
|
||||||
|
loop = _mk_loop()
|
||||||
|
session = Session(key="test:tool-result")
|
||||||
|
content = "x" * 12_000
|
||||||
|
|
||||||
|
loop._save_turn(
|
||||||
|
session,
|
||||||
|
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
|
||||||
|
skip=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session.messages[0]["content"] == content
|
||||||
|
|||||||
76
tests/test_restart_command.py
Normal file
76
tests/test_restart_command.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Tests for /restart slash command."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop():
|
||||||
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
workspace = MagicMock()
|
||||||
|
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager"):
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||||
|
return loop, bus
|
||||||
|
|
||||||
|
|
||||||
|
class TestRestartCommand:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_restart_sends_message_and_calls_execv(self):
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.os.execv") as mock_execv:
|
||||||
|
await loop._handle_restart(msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "Restarting" in out.content
|
||||||
|
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
mock_execv.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_restart_intercepted_in_run_loop(self):
|
||||||
|
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
|
||||||
|
|
||||||
|
with patch.object(loop, "_handle_restart") as mock_handle:
|
||||||
|
mock_handle.return_value = None
|
||||||
|
await bus.publish_inbound(msg)
|
||||||
|
|
||||||
|
loop._running = True
|
||||||
|
run_task = asyncio.create_task(loop.run())
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
loop._running = False
|
||||||
|
run_task.cancel()
|
||||||
|
try:
|
||||||
|
await run_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_handle.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_includes_restart(self):
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
|
||||||
|
|
||||||
|
response = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "/restart" in response.content
|
||||||
Reference in New Issue
Block a user