Merge branch 'main' into pr-1916
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
.worktrees/
|
.worktrees/
|
||||||
.assets
|
.assets
|
||||||
|
.docs
|
||||||
.env
|
.env
|
||||||
*.pyc
|
*.pyc
|
||||||
dist/
|
dist/
|
||||||
@@ -7,7 +8,7 @@ build/
|
|||||||
docs/
|
docs/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
*.egg
|
*.egg
|
||||||
*.pyc
|
*.pycs
|
||||||
*.pyo
|
*.pyo
|
||||||
*.pyd
|
*.pyd
|
||||||
*.pyw
|
*.pyw
|
||||||
|
|||||||
@@ -502,7 +502,8 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
"appSecret": "xxx",
|
"appSecret": "xxx",
|
||||||
"encryptKey": "",
|
"encryptKey": "",
|
||||||
"verificationToken": "",
|
"verificationToken": "",
|
||||||
"allowFrom": ["ou_YOUR_OPEN_ID"]
|
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
||||||
|
"groupPolicy": "mention"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -510,6 +511,7 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
|
|
||||||
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
||||||
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
||||||
|
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
|
||||||
|
|
||||||
**3. Run**
|
**3. Run**
|
||||||
|
|
||||||
@@ -756,15 +758,17 @@ Config file: `~/.nanobot/config.json`
|
|||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
||||||
|
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
|
||||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||||
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
|
||||||
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
|
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
|
||||||
|
|
||||||
| Provider | Purpose | Get API Key |
|
| Provider | Purpose | Get API Key |
|
||||||
|----------|---------|-------------|
|
|----------|---------|-------------|
|
||||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||||
|
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
|
||||||
|
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
|
||||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||||
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||||
@@ -774,7 +778,6 @@ Config file: `~/.nanobot/config.json`
|
|||||||
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
||||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||||
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||||
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
|
|
||||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class AgentLoop:
|
|||||||
await self._mcp_stack.__aenter__()
|
await self._mcp_stack.__aenter__()
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||||
self._mcp_connected = True
|
self._mcp_connected = True
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
if self._mcp_stack:
|
if self._mcp_stack:
|
||||||
try:
|
try:
|
||||||
@@ -292,7 +292,9 @@ class AgentLoop:
|
|||||||
|
|
||||||
async def _do_restart():
|
async def _do_restart():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
os.execv(sys.executable, [sys.executable] + sys.argv)
|
# Use -m nanobot instead of sys.argv[0] for Windows compatibility
|
||||||
|
# (sys.argv[0] may be just "nanobot" without full path on Windows)
|
||||||
|
os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
|
||||||
|
|
||||||
asyncio.create_task(_do_restart())
|
asyncio.create_task(_do_restart())
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import weakref
|
import weakref
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
@@ -57,13 +58,30 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
|||||||
return args[0] if args and isinstance(args[0], dict) else None
|
return args[0] if args and isinstance(args[0], dict) else None
|
||||||
return args if isinstance(args, dict) else None
|
return args if isinstance(args, dict) else None
|
||||||
|
|
||||||
|
_TOOL_CHOICE_ERROR_MARKERS = (
|
||||||
|
"tool_choice",
|
||||||
|
"toolchoice",
|
||||||
|
"does not support",
|
||||||
|
'should be ["none", "auto"]',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tool_choice_unsupported(content: str | None) -> bool:
|
||||||
|
"""Detect provider errors caused by forced tool_choice being unsupported."""
|
||||||
|
text = (content or "").lower()
|
||||||
|
return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS)
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
|
|
||||||
|
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.memory_dir = ensure_dir(workspace / "memory")
|
self.memory_dir = ensure_dir(workspace / "memory")
|
||||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||||
self.history_file = self.memory_dir / "HISTORY.md"
|
self.history_file = self.memory_dir / "HISTORY.md"
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
|
||||||
def read_long_term(self) -> str:
|
def read_long_term(self) -> str:
|
||||||
if self.memory_file.exists():
|
if self.memory_file.exists():
|
||||||
@@ -112,38 +130,93 @@ class MemoryStore:
|
|||||||
## Conversation to Process
|
## Conversation to Process
|
||||||
{self._format_messages(messages)}"""
|
{self._format_messages(messages)}"""
|
||||||
|
|
||||||
try:
|
chat_messages = [
|
||||||
response = await provider.chat_with_retry(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
],
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
forced = {"type": "function", "function": {"name": "save_memory"}}
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=chat_messages,
|
||||||
tools=_SAVE_MEMORY_TOOL,
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
model=model,
|
model=model,
|
||||||
tool_choice="required",
|
tool_choice=forced,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.finish_reason == "error" and _is_tool_choice_unsupported(
|
||||||
|
response.content
|
||||||
|
):
|
||||||
|
logger.warning("Forced tool_choice unsupported, retrying with auto")
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=chat_messages,
|
||||||
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
|
model=model,
|
||||||
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
logger.warning(
|
||||||
return False
|
"Memory consolidation: LLM did not call save_memory "
|
||||||
|
"(finish_reason={}, content_len={}, content_preview={})",
|
||||||
|
response.finish_reason,
|
||||||
|
len(response.content or ""),
|
||||||
|
(response.content or "")[:200],
|
||||||
|
)
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||||
if args is None:
|
if args is None:
|
||||||
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||||
return False
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if "history_entry" not in args or "memory_update" not in args:
|
||||||
self.append_history(_ensure_text(entry))
|
logger.warning("Memory consolidation: save_memory payload missing required fields")
|
||||||
if update := args.get("memory_update"):
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
entry = args["history_entry"]
|
||||||
|
update = args["memory_update"]
|
||||||
|
|
||||||
|
if entry is None or update is None:
|
||||||
|
logger.warning("Memory consolidation: save_memory payload contains null required fields")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
entry = _ensure_text(entry).strip()
|
||||||
|
if not entry:
|
||||||
|
logger.warning("Memory consolidation: history_entry is empty after normalization")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
self.append_history(entry)
|
||||||
update = _ensure_text(update)
|
update = _ensure_text(update)
|
||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
self.write_long_term(update)
|
self.write_long_term(update)
|
||||||
|
|
||||||
|
self._consecutive_failures = 0
|
||||||
logger.info("Memory consolidation done for {} messages", len(messages))
|
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Memory consolidation failed")
|
logger.exception("Memory consolidation failed")
|
||||||
|
return self._fail_or_raw_archive(messages)
|
||||||
|
|
||||||
|
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
|
||||||
|
"""Increment failure count; after threshold, raw-archive messages and return True."""
|
||||||
|
self._consecutive_failures += 1
|
||||||
|
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
|
||||||
return False
|
return False
|
||||||
|
self._raw_archive(messages)
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _raw_archive(self, messages: list[dict]) -> None:
|
||||||
|
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
|
||||||
|
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||||
|
self.append_history(
|
||||||
|
f"[{ts}] [RAW] {len(messages)} messages\n"
|
||||||
|
f"{self._format_messages(messages)}"
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Memory consolidation degraded: raw-archived {} messages", len(messages)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryConsolidator:
|
class MemoryConsolidator:
|
||||||
|
|||||||
@@ -352,6 +352,27 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Feishu bot stopped")
|
logger.info("Feishu bot stopped")
|
||||||
|
|
||||||
|
def _is_bot_mentioned(self, message: Any) -> bool:
|
||||||
|
"""Check if the bot is @mentioned in the message."""
|
||||||
|
raw_content = message.content or ""
|
||||||
|
if "@_all" in raw_content:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for mention in getattr(message, "mentions", None) or []:
|
||||||
|
mid = getattr(mention, "id", None)
|
||||||
|
if not mid:
|
||||||
|
continue
|
||||||
|
# Bot mentions have no user_id (None or "") but a valid open_id
|
||||||
|
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_group_message_for_bot(self, message: Any) -> bool:
|
||||||
|
"""Allow group messages when policy is open or bot is @mentioned."""
|
||||||
|
if self.config.group_policy == "open":
|
||||||
|
return True
|
||||||
|
return self._is_bot_mentioned(message)
|
||||||
|
|
||||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||||
@@ -893,6 +914,10 @@ class FeishuChannel(BaseChannel):
|
|||||||
chat_type = message.chat_type
|
chat_type = message.chat_type
|
||||||
msg_type = message.message_type
|
msg_type = message.message_type
|
||||||
|
|
||||||
|
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
||||||
|
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||||
|
return
|
||||||
|
|
||||||
# Add reaction
|
# Add reaction
|
||||||
await self._add_reaction(message_id, self.config.react_emoji)
|
await self._add_reaction(message_id, self.config.react_emoji)
|
||||||
|
|
||||||
|
|||||||
@@ -114,16 +114,16 @@ class QQChannel(BaseChannel):
|
|||||||
if msg_type == "group":
|
if msg_type == "group":
|
||||||
await self._client.api.post_group_message(
|
await self._client.api.post_group_message(
|
||||||
group_openid=msg.chat_id,
|
group_openid=msg.chat_id,
|
||||||
msg_type=2,
|
msg_type=0,
|
||||||
markdown={"content": msg.content},
|
content=msg.content,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
msg_seq=self._msg_seq,
|
msg_seq=self._msg_seq,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._client.api.post_c2c_message(
|
await self._client.api.post_c2c_message(
|
||||||
openid=msg.chat_id,
|
openid=msg.chat_id,
|
||||||
msg_type=2,
|
msg_type=0,
|
||||||
markdown={"content": msg.content},
|
content=msg.content,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
msg_seq=self._msg_seq,
|
msg_seq=self._msg_seq,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from nanobot.config.schema import TelegramConfig
|
|||||||
from nanobot.utils.helpers import split_message
|
from nanobot.utils.helpers import split_message
|
||||||
|
|
||||||
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||||
|
TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
|
||||||
|
|
||||||
|
|
||||||
def _strip_md(s: str) -> str:
|
def _strip_md(s: str) -> str:
|
||||||
@@ -453,6 +454,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_message_metadata(message, user) -> dict:
|
def _build_message_metadata(message, user) -> dict:
|
||||||
"""Build common Telegram inbound metadata payload."""
|
"""Build common Telegram inbound metadata payload."""
|
||||||
|
reply_to = getattr(message, "reply_to_message", None)
|
||||||
return {
|
return {
|
||||||
"message_id": message.message_id,
|
"message_id": message.message_id,
|
||||||
"user_id": user.id,
|
"user_id": user.id,
|
||||||
@@ -461,8 +463,73 @@ class TelegramChannel(BaseChannel):
|
|||||||
"is_group": message.chat.type != "private",
|
"is_group": message.chat.type != "private",
|
||||||
"message_thread_id": getattr(message, "message_thread_id", None),
|
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||||
|
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_reply_context(message) -> str | None:
|
||||||
|
"""Extract text from the message being replied to, if any."""
|
||||||
|
reply = getattr(message, "reply_to_message", None)
|
||||||
|
if not reply:
|
||||||
|
return None
|
||||||
|
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
|
||||||
|
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
|
||||||
|
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
|
||||||
|
return f"[Reply to: {text}]" if text else None
|
||||||
|
|
||||||
|
async def _download_message_media(
|
||||||
|
self, msg, *, add_failure_content: bool = False
|
||||||
|
) -> tuple[list[str], list[str]]:
|
||||||
|
"""Download media from a message (current or reply). Returns (media_paths, content_parts)."""
|
||||||
|
media_file = None
|
||||||
|
media_type = None
|
||||||
|
if getattr(msg, "photo", None):
|
||||||
|
media_file = msg.photo[-1]
|
||||||
|
media_type = "image"
|
||||||
|
elif getattr(msg, "voice", None):
|
||||||
|
media_file = msg.voice
|
||||||
|
media_type = "voice"
|
||||||
|
elif getattr(msg, "audio", None):
|
||||||
|
media_file = msg.audio
|
||||||
|
media_type = "audio"
|
||||||
|
elif getattr(msg, "document", None):
|
||||||
|
media_file = msg.document
|
||||||
|
media_type = "file"
|
||||||
|
elif getattr(msg, "video", None):
|
||||||
|
media_file = msg.video
|
||||||
|
media_type = "video"
|
||||||
|
elif getattr(msg, "video_note", None):
|
||||||
|
media_file = msg.video_note
|
||||||
|
media_type = "video"
|
||||||
|
elif getattr(msg, "animation", None):
|
||||||
|
media_file = msg.animation
|
||||||
|
media_type = "animation"
|
||||||
|
if not media_file or not self._app:
|
||||||
|
return [], []
|
||||||
|
try:
|
||||||
|
file = await self._app.bot.get_file(media_file.file_id)
|
||||||
|
ext = self._get_extension(
|
||||||
|
media_type,
|
||||||
|
getattr(media_file, "mime_type", None),
|
||||||
|
getattr(media_file, "file_name", None),
|
||||||
|
)
|
||||||
|
media_dir = get_media_dir("telegram")
|
||||||
|
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||||
|
await file.download_to_drive(str(file_path))
|
||||||
|
path_str = str(file_path)
|
||||||
|
if media_type in ("voice", "audio"):
|
||||||
|
transcription = await self.transcribe_audio(file_path)
|
||||||
|
if transcription:
|
||||||
|
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||||
|
return [path_str], [f"[transcription: {transcription}]"]
|
||||||
|
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||||
|
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to download message media: {}", e)
|
||||||
|
if add_failure_content:
|
||||||
|
return [], [f"[{media_type}: download failed]"]
|
||||||
|
return [], []
|
||||||
|
|
||||||
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
|
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
|
||||||
"""Load bot identity once and reuse it for mention/reply checks."""
|
"""Load bot identity once and reuse it for mention/reply checks."""
|
||||||
if self._bot_user_id is not None or self._bot_username is not None:
|
if self._bot_user_id is not None or self._bot_username is not None:
|
||||||
@@ -547,7 +614,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=self._sender_id(user),
|
sender_id=self._sender_id(user),
|
||||||
chat_id=str(message.chat_id),
|
chat_id=str(message.chat_id),
|
||||||
content=message.text,
|
content=message.text or "",
|
||||||
metadata=self._build_message_metadata(message, user),
|
metadata=self._build_message_metadata(message, user),
|
||||||
session_key=self._derive_topic_session_key(message),
|
session_key=self._derive_topic_session_key(message),
|
||||||
)
|
)
|
||||||
@@ -579,54 +646,26 @@ class TelegramChannel(BaseChannel):
|
|||||||
if message.caption:
|
if message.caption:
|
||||||
content_parts.append(message.caption)
|
content_parts.append(message.caption)
|
||||||
|
|
||||||
# Handle media files
|
# Download current message media
|
||||||
media_file = None
|
current_media_paths, current_media_parts = await self._download_message_media(
|
||||||
media_type = None
|
message, add_failure_content=True
|
||||||
|
|
||||||
if message.photo:
|
|
||||||
media_file = message.photo[-1] # Largest photo
|
|
||||||
media_type = "image"
|
|
||||||
elif message.voice:
|
|
||||||
media_file = message.voice
|
|
||||||
media_type = "voice"
|
|
||||||
elif message.audio:
|
|
||||||
media_file = message.audio
|
|
||||||
media_type = "audio"
|
|
||||||
elif message.document:
|
|
||||||
media_file = message.document
|
|
||||||
media_type = "file"
|
|
||||||
|
|
||||||
# Download media if present
|
|
||||||
if media_file and self._app:
|
|
||||||
try:
|
|
||||||
file = await self._app.bot.get_file(media_file.file_id)
|
|
||||||
ext = self._get_extension(
|
|
||||||
media_type,
|
|
||||||
getattr(media_file, 'mime_type', None),
|
|
||||||
getattr(media_file, 'file_name', None),
|
|
||||||
)
|
)
|
||||||
media_dir = get_media_dir("telegram")
|
media_paths.extend(current_media_paths)
|
||||||
|
content_parts.extend(current_media_parts)
|
||||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
if current_media_paths:
|
||||||
await file.download_to_drive(str(file_path))
|
logger.debug("Downloaded message media to {}", current_media_paths[0])
|
||||||
|
|
||||||
media_paths.append(str(file_path))
|
|
||||||
|
|
||||||
if media_type in ("voice", "audio"):
|
|
||||||
transcription = await self.transcribe_audio(file_path)
|
|
||||||
if transcription:
|
|
||||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
|
||||||
content_parts.append(f"[transcription: {transcription}]")
|
|
||||||
else:
|
|
||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
|
||||||
else:
|
|
||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
|
||||||
|
|
||||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to download media: {}", e)
|
|
||||||
content_parts.append(f"[{media_type}: download failed]")
|
|
||||||
|
|
||||||
|
# Reply context: text and/or media from the replied-to message
|
||||||
|
reply = getattr(message, "reply_to_message", None)
|
||||||
|
if reply is not None:
|
||||||
|
reply_ctx = self._extract_reply_context(message)
|
||||||
|
reply_media, reply_media_parts = await self._download_message_media(reply)
|
||||||
|
if reply_media:
|
||||||
|
media_paths = reply_media + media_paths
|
||||||
|
logger.debug("Attached replied-to media: {}", reply_media[0])
|
||||||
|
tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
|
||||||
|
if tag:
|
||||||
|
content_parts.insert(0, tag)
|
||||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||||
|
|
||||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ if sys.platform == "win32":
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from prompt_toolkit import print_formatted_text
|
||||||
from prompt_toolkit import PromptSession
|
from prompt_toolkit import PromptSession
|
||||||
from prompt_toolkit.formatted_text import HTML
|
from prompt_toolkit.formatted_text import ANSI, HTML
|
||||||
from prompt_toolkit.history import FileHistory
|
from prompt_toolkit.history import FileHistory
|
||||||
from prompt_toolkit.patch_stdout import patch_stdout
|
from prompt_toolkit.patch_stdout import patch_stdout
|
||||||
|
from prompt_toolkit.application import run_in_terminal
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
@@ -111,8 +113,25 @@ def _init_prompt_session() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_console() -> Console:
|
||||||
|
return Console(file=sys.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_interactive_ansi(render_fn) -> str:
|
||||||
|
"""Render Rich output to ANSI so prompt_toolkit can print it safely."""
|
||||||
|
ansi_console = Console(
|
||||||
|
force_terminal=True,
|
||||||
|
color_system=console.color_system or "standard",
|
||||||
|
width=console.width,
|
||||||
|
)
|
||||||
|
with ansi_console.capture() as capture:
|
||||||
|
render_fn(ansi_console)
|
||||||
|
return capture.get()
|
||||||
|
|
||||||
|
|
||||||
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||||
"""Render assistant response with consistent terminal styling."""
|
"""Render assistant response with consistent terminal styling."""
|
||||||
|
console = _make_console()
|
||||||
content = response or ""
|
content = response or ""
|
||||||
body = Markdown(content) if render_markdown else Text(content)
|
body = Markdown(content) if render_markdown else Text(content)
|
||||||
console.print()
|
console.print()
|
||||||
@@ -121,6 +140,34 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
|
|||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
|
||||||
|
async def _print_interactive_line(text: str) -> None:
|
||||||
|
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
|
||||||
|
def _write() -> None:
|
||||||
|
ansi = _render_interactive_ansi(
|
||||||
|
lambda c: c.print(f" [dim]↳ {text}[/dim]")
|
||||||
|
)
|
||||||
|
print_formatted_text(ANSI(ansi), end="")
|
||||||
|
|
||||||
|
await run_in_terminal(_write)
|
||||||
|
|
||||||
|
|
||||||
|
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
|
||||||
|
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
|
||||||
|
def _write() -> None:
|
||||||
|
content = response or ""
|
||||||
|
ansi = _render_interactive_ansi(
|
||||||
|
lambda c: (
|
||||||
|
c.print(),
|
||||||
|
c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
|
||||||
|
c.print(Markdown(content) if render_markdown else Text(content)),
|
||||||
|
c.print(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print_formatted_text(ANSI(ansi), end="")
|
||||||
|
|
||||||
|
await run_in_terminal(_write)
|
||||||
|
|
||||||
|
|
||||||
def _is_exit_command(command: str) -> bool:
|
def _is_exit_command(command: str) -> bool:
|
||||||
"""Return True when input should end interactive chat."""
|
"""Return True when input should end interactive chat."""
|
||||||
return command.lower() in EXIT_COMMANDS
|
return command.lower() in EXIT_COMMANDS
|
||||||
@@ -610,14 +657,15 @@ def agent(
|
|||||||
elif ch and not is_tool_hint and not ch.send_progress:
|
elif ch and not is_tool_hint and not ch.send_progress:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
console.print(f" [dim]↳ {msg.content}[/dim]")
|
await _print_interactive_line(msg.content)
|
||||||
|
|
||||||
elif not turn_done.is_set():
|
elif not turn_done.is_set():
|
||||||
if msg.content:
|
if msg.content:
|
||||||
turn_response.append(msg.content)
|
turn_response.append(msg.content)
|
||||||
turn_done.set()
|
turn_done.set()
|
||||||
elif msg.content:
|
elif msg.content:
|
||||||
console.print()
|
await _print_interactive_response(msg.content, render_markdown=markdown)
|
||||||
_print_agent_response(msg.content, render_markdown=markdown)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class FeishuConfig(Base):
|
|||||||
react_emoji: str = (
|
react_emoji: str = (
|
||||||
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
||||||
)
|
)
|
||||||
|
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned, "open" responds to all
|
||||||
|
|
||||||
|
|
||||||
class DingTalkConfig(Base):
|
class DingTalkConfig(Base):
|
||||||
@@ -275,15 +276,18 @@ class ProvidersConfig(Base):
|
|||||||
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
|
dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
|
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
||||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
|
||||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||||
|
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
||||||
|
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
|
||||||
|
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
|
||||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||||
|
|
||||||
@@ -397,12 +401,21 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
# Fallback: configured local providers can route models without
|
# Fallback: configured local providers can route models without
|
||||||
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
||||||
|
# Prefer providers whose detect_by_base_keyword matches the configured api_base
|
||||||
|
# (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
|
||||||
|
local_fallback: tuple[ProviderConfig, str] | None = None
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
if not spec.is_local:
|
if not spec.is_local:
|
||||||
continue
|
continue
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and p.api_base:
|
if not (p and p.api_base):
|
||||||
|
continue
|
||||||
|
if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
|
||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
if local_fallback is None:
|
||||||
|
local_fallback = (p, spec.name)
|
||||||
|
if local_fallback:
|
||||||
|
return local_fallback
|
||||||
|
|
||||||
# Fallback: gateways first, then others (follows registry order)
|
# Fallback: gateways first, then others (follows registry order)
|
||||||
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
||||||
|
|||||||
@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
|
||||||
|
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="volcengine",
|
name="volcengine",
|
||||||
keywords=("volcengine", "volces", "ark"),
|
keywords=("volcengine", "volces", "ark"),
|
||||||
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
|
||||||
|
ProviderSpec(
|
||||||
|
name="volcengine_coding_plan",
|
||||||
|
keywords=("volcengine-plan",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="VolcEngine Coding Plan",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="",
|
||||||
|
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
# BytePlus: VolcEngine international, pay-per-use models
|
||||||
|
ProviderSpec(
|
||||||
|
name="byteplus",
|
||||||
|
keywords=("byteplus",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="BytePlus",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="bytepluses",
|
||||||
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
# BytePlus Coding Plan: same key as byteplus
|
||||||
|
ProviderSpec(
|
||||||
|
name="byteplus_coding_plan",
|
||||||
|
keywords=("byteplus-plan",),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="BytePlus Coding Plan",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="",
|
||||||
|
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
|
||||||
|
strip_model_prefix=True,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
# === Standard providers (matched by model-name keywords) ===============
|
# === Standard providers (matched by model-name keywords) ===============
|
||||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
|
|||||||
@@ -75,13 +75,6 @@ build-backend = "hatchling.build"
|
|||||||
[tool.hatch.metadata]
|
[tool.hatch.metadata]
|
||||||
allow-direct-references = true
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
|
||||||
packages = ["nanobot"]
|
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel.sources]
|
|
||||||
"nanobot" = "nanobot"
|
|
||||||
|
|
||||||
# Include non-Python files in skills and templates
|
|
||||||
[tool.hatch.build]
|
[tool.hatch.build]
|
||||||
include = [
|
include = [
|
||||||
"nanobot/**/*.py",
|
"nanobot/**/*.py",
|
||||||
@@ -90,6 +83,15 @@ include = [
|
|||||||
"nanobot/skills/**/*.sh",
|
"nanobot/skills/**/*.sh",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["nanobot"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.sources]
|
||||||
|
"nanobot" = "nanobot"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.force-include]
|
||||||
|
"bridge" = "nanobot/bridge"
|
||||||
|
|
||||||
[tool.hatch.build.targets.sdist]
|
[tool.hatch.build.targets.sdist]
|
||||||
include = [
|
include = [
|
||||||
"nanobot/",
|
"nanobot/",
|
||||||
@@ -98,9 +100,6 @@ include = [
|
|||||||
"LICENSE",
|
"LICENSE",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel.force-include]
|
|
||||||
"bridge" = "nanobot/bridge"
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
|
|||||||
@@ -150,6 +150,35 @@ def test_config_auto_detects_ollama_from_local_api_base():
|
|||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
|
"providers": {
|
||||||
|
"vllm": {"apiBase": "http://localhost:8000"},
|
||||||
|
"ollama": {"apiBase": "http://localhost:11434"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "ollama"
|
||||||
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||||
|
"providers": {
|
||||||
|
"vllm": {"apiBase": "http://localhost:8000"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "vllm"
|
||||||
|
assert config.get_api_base() == "http://localhost:8000"
|
||||||
|
|
||||||
|
|
||||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||||
|
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a JSON string (not yet parsed)
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -170,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a list containing a dict
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -242,6 +240,94 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not persist partial results when required fields are missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"memory_update": "# Memory\nOnly memory update"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not append history if memory_update is missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"history_entry": "[2026-01-01] Partial output."},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Null required fields should be rejected before persistence."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=None,
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=" ",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
@@ -288,3 +374,105 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
assert "temperature" not in kwargs
|
assert "temperature" not in kwargs
|
||||||
assert "max_tokens" not in kwargs
|
assert "max_tokens" not in kwargs
|
||||||
assert "reasoning_effort" not in kwargs
|
assert "reasoning_effort" not in kwargs
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
|
||||||
|
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
error_resp = LLMResponse(
|
||||||
|
content="Error calling LLM: litellm.BadRequestError: "
|
||||||
|
"The tool_choice parameter does not support being set to required or object",
|
||||||
|
finish_reason="error",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
ok_resp = _make_tool_response(
|
||||||
|
history_entry="[2026-01-01] Fallback worked.",
|
||||||
|
memory_update="# Memory\nFallback OK.",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_log: list[dict] = []
|
||||||
|
|
||||||
|
async def _tracking_chat(**kwargs):
|
||||||
|
call_log.append(kwargs)
|
||||||
|
return error_resp if len(call_log) == 1 else ok_resp
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(call_log) == 2
|
||||||
|
assert isinstance(call_log[0]["tool_choice"], dict)
|
||||||
|
assert call_log[1]["tool_choice"] == "auto"
|
||||||
|
assert "Fallback worked." in store.history_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
|
||||||
|
"""Forced rejected, auto retry also produces no tool call -> return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
error_resp = LLMResponse(
|
||||||
|
content="Error: tool_choice must be none or auto",
|
||||||
|
finish_reason="error",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
no_tool_resp = LLMResponse(
|
||||||
|
content="Here is a summary.",
|
||||||
|
finish_reason="stop",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
|
||||||
|
"""After 3 consecutive failures, raw-archive messages and return True."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
messages = _make_messages(message_count=10)
|
||||||
|
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is True
|
||||||
|
|
||||||
|
assert store.history_file.exists()
|
||||||
|
content = store.history_file.read_text()
|
||||||
|
assert "[RAW]" in content
|
||||||
|
assert "10 messages" in content
|
||||||
|
assert "msg0" in content
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
|
||||||
|
"""A successful consolidation resets the failure counter."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
|
||||||
|
ok_resp = _make_tool_response(
|
||||||
|
history_entry="[2026-01-01] OK.",
|
||||||
|
memory_update="# Memory\nOK.",
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=10)
|
||||||
|
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert store._consecutive_failures == 2
|
||||||
|
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is True
|
||||||
|
assert store._consecutive_failures == 0
|
||||||
|
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=no_tool)
|
||||||
|
assert await store.consolidate(messages, provider, "m") is False
|
||||||
|
assert store._consecutive_failures == 1
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
|
||||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
channel._client = _FakeClient()
|
channel._client = _FakeClient()
|
||||||
channel._chat_type_cache["group123"] = "group"
|
channel._chat_type_cache["group123"] = "group"
|
||||||
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
|||||||
|
|
||||||
assert len(channel._client.api.group_calls) == 1
|
assert len(channel._client.api.group_calls) == 1
|
||||||
call = channel._client.api.group_calls[0]
|
call = channel._client.api.group_calls[0]
|
||||||
assert call["group_openid"] == "group123"
|
assert call == {
|
||||||
assert call["msg_id"] == "msg1"
|
"group_openid": "group123",
|
||||||
assert call["msg_seq"] == 2
|
"msg_type": 0,
|
||||||
|
"content": "hello",
|
||||||
|
"msg_id": "msg1",
|
||||||
|
"msg_seq": 2,
|
||||||
|
}
|
||||||
assert not channel._client.api.c2c_calls
|
assert not channel._client.api.c2c_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "msg1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(channel._client.api.c2c_calls) == 1
|
||||||
|
call = channel._client.api.c2c_calls[0]
|
||||||
|
assert call == {
|
||||||
|
"openid": "user123",
|
||||||
|
"msg_type": 0,
|
||||||
|
"content": "hello",
|
||||||
|
"msg_id": "msg1",
|
||||||
|
"msg_seq": 2,
|
||||||
|
}
|
||||||
|
assert not channel._client.api.group_calls
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.telegram import TelegramChannel
|
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
|
||||||
from nanobot.config.schema import TelegramConfig
|
from nanobot.config.schema import TelegramConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -42,6 +45,12 @@ class _FakeBot:
|
|||||||
async def send_chat_action(self, **kwargs) -> None:
|
async def send_chat_action(self, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def get_file(self, file_id: str):
|
||||||
|
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
|
||||||
|
async def _fake_download(path) -> None:
|
||||||
|
pass
|
||||||
|
return SimpleNamespace(download_to_drive=_fake_download)
|
||||||
|
|
||||||
|
|
||||||
class _FakeApp:
|
class _FakeApp:
|
||||||
def __init__(self, on_start_polling) -> None:
|
def __init__(self, on_start_polling) -> None:
|
||||||
@@ -336,3 +345,255 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
|
|||||||
|
|
||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
assert channel._app.bot.get_me_calls == 0
|
assert channel._app.bot.get_me_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_reply_context_no_reply() -> None:
|
||||||
|
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||||
|
message = SimpleNamespace(reply_to_message=None)
|
||||||
|
assert TelegramChannel._extract_reply_context(message) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_reply_context_with_text() -> None:
|
||||||
|
"""When reply has text, return prefixed string."""
|
||||||
|
reply = SimpleNamespace(text="Hello world", caption=None)
|
||||||
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
|
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_reply_context_with_caption_only() -> None:
|
||||||
|
"""When reply has only caption (no text), caption is used."""
|
||||||
|
reply = SimpleNamespace(text=None, caption="Photo caption")
|
||||||
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
|
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_reply_context_truncation() -> None:
|
||||||
|
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||||
|
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||||
|
reply = SimpleNamespace(text=long_text, caption=None)
|
||||||
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
|
result = TelegramChannel._extract_reply_context(message)
|
||||||
|
assert result is not None
|
||||||
|
assert result.startswith("[Reply to: ")
|
||||||
|
assert result.endswith("...]")
|
||||||
|
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_reply_context_no_text_returns_none() -> None:
|
||||||
|
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||||
|
reply = SimpleNamespace(text=None, caption=None)
|
||||||
|
message = SimpleNamespace(reply_to_message=reply)
|
||||||
|
assert TelegramChannel._extract_reply_context(message) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_includes_reply_context() -> None:
|
||||||
|
"""When user replies to a message, content passed to bus starts with reply context."""
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
handled = []
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
|
||||||
|
update = _make_telegram_update(text="translate this", reply_to_message=reply)
|
||||||
|
await channel._on_message(update, None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["content"].startswith("[Reply to: Hello]")
|
||||||
|
assert "translate this" in handled[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_message_media_returns_path_when_download_succeeds(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
) -> None:
|
||||||
|
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
|
||||||
|
media_dir = tmp_path / "media" / "telegram"
|
||||||
|
media_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.telegram.get_media_dir",
|
||||||
|
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
channel._app.bot.get_file = AsyncMock(
|
||||||
|
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = SimpleNamespace(
|
||||||
|
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
|
||||||
|
voice=None,
|
||||||
|
audio=None,
|
||||||
|
document=None,
|
||||||
|
video=None,
|
||||||
|
video_note=None,
|
||||||
|
animation=None,
|
||||||
|
)
|
||||||
|
paths, parts = await channel._download_message_media(msg)
|
||||||
|
assert len(paths) == 1
|
||||||
|
assert len(parts) == 1
|
||||||
|
assert "fid123" in paths[0]
|
||||||
|
assert "[image:" in parts[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
|
||||||
|
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
|
||||||
|
media_dir = tmp_path / "media" / "telegram"
|
||||||
|
media_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.telegram.get_media_dir",
|
||||||
|
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
app = _FakeApp(lambda: None)
|
||||||
|
app.bot.get_file = AsyncMock(
|
||||||
|
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||||
|
)
|
||||||
|
channel._app = app
|
||||||
|
handled = []
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
reply_with_photo = SimpleNamespace(
|
||||||
|
text=None,
|
||||||
|
caption=None,
|
||||||
|
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
|
||||||
|
document=None,
|
||||||
|
voice=None,
|
||||||
|
audio=None,
|
||||||
|
video=None,
|
||||||
|
video_note=None,
|
||||||
|
animation=None,
|
||||||
|
)
|
||||||
|
update = _make_telegram_update(
|
||||||
|
text="what is the image?",
|
||||||
|
reply_to_message=reply_with_photo,
|
||||||
|
)
|
||||||
|
await channel._on_message(update, None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["content"].startswith("[Reply to: [image:")
|
||||||
|
assert "what is the image?" in handled[0]["content"]
|
||||||
|
assert len(handled[0]["media"]) == 1
|
||||||
|
assert "reply_photo_fid" in handled[0]["media"][0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
|
||||||
|
"""When reply has media but download fails, no media attached and no reply tag."""
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
channel._app.bot.get_file = None
|
||||||
|
handled = []
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
reply_with_photo = SimpleNamespace(
|
||||||
|
text=None,
|
||||||
|
caption=None,
|
||||||
|
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
|
||||||
|
document=None,
|
||||||
|
voice=None,
|
||||||
|
audio=None,
|
||||||
|
video=None,
|
||||||
|
video_note=None,
|
||||||
|
animation=None,
|
||||||
|
)
|
||||||
|
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
|
||||||
|
await channel._on_message(update, None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert "what is this?" in handled[0]["content"]
|
||||||
|
assert handled[0]["media"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
|
||||||
|
"""When replying to a message with caption + photo, both text context and media are included."""
|
||||||
|
media_dir = tmp_path / "media" / "telegram"
|
||||||
|
media_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.telegram.get_media_dir",
|
||||||
|
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
app = _FakeApp(lambda: None)
|
||||||
|
app.bot.get_file = AsyncMock(
|
||||||
|
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||||
|
)
|
||||||
|
channel._app = app
|
||||||
|
handled = []
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
reply_with_caption_and_photo = SimpleNamespace(
|
||||||
|
text=None,
|
||||||
|
caption="A cute cat",
|
||||||
|
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
|
||||||
|
document=None,
|
||||||
|
voice=None,
|
||||||
|
audio=None,
|
||||||
|
video=None,
|
||||||
|
video_note=None,
|
||||||
|
animation=None,
|
||||||
|
)
|
||||||
|
update = _make_telegram_update(
|
||||||
|
text="what breed is this?",
|
||||||
|
reply_to_message=reply_with_caption_and_photo,
|
||||||
|
)
|
||||||
|
await channel._on_message(update, None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert "[Reply to: A cute cat]" in handled[0]["content"]
|
||||||
|
assert "what breed is this?" in handled[0]["content"]
|
||||||
|
assert len(handled[0]["media"]) == 1
|
||||||
|
assert "cat_fid" in handled[0]["media"][0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_forward_command_does_not_inject_reply_context() -> None:
|
||||||
|
"""Slash commands forwarded via _forward_command must not include reply context."""
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
handled = []
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
|
||||||
|
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
|
||||||
|
update = _make_telegram_update(text="/new", reply_to_message=reply)
|
||||||
|
await channel._forward_command(update, None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["content"] == "/new"
|
||||||
|
|||||||
Reference in New Issue
Block a user