Merge pull request #1482 from contributors/telegram-topic-support

This commit is contained in:
Re-bin
2026-03-07 15:33:24 +00:00
4 changed files with 137 additions and 26 deletions

View File

@@ -96,7 +96,7 @@ class MessageTool(Tool):
media=media or [], media=media or [],
metadata={ metadata={
"message_id": message_id, "message_id": message_id,
} },
) )
try: try:

View File

@@ -177,6 +177,7 @@ class TelegramChannel(BaseChannel):
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {} self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {} self._media_group_tasks: dict[str, asyncio.Task] = {}
self._message_threads: dict[tuple[str, int], int] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
@@ -286,10 +287,16 @@ class TelegramChannel(BaseChannel):
except ValueError: except ValueError:
logger.error("Invalid chat_id: {}", msg.chat_id) logger.error("Invalid chat_id: {}", msg.chat_id)
return return
reply_to_message_id = msg.metadata.get("message_id")
message_thread_id = msg.metadata.get("message_thread_id")
if message_thread_id is None and reply_to_message_id is not None:
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
thread_kwargs = {}
if message_thread_id is not None:
thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None reply_params = None
if self.config.reply_to_message: if self.config.reply_to_message:
reply_to_message_id = msg.metadata.get("message_id")
if reply_to_message_id: if reply_to_message_id:
reply_params = ReplyParameters( reply_params = ReplyParameters(
message_id=reply_to_message_id, message_id=reply_to_message_id,
@@ -310,7 +317,8 @@ class TelegramChannel(BaseChannel):
await sender( await sender(
chat_id=chat_id, chat_id=chat_id,
**{param: f}, **{param: f},
reply_parameters=reply_params reply_parameters=reply_params,
**thread_kwargs,
) )
except Exception as e: except Exception as e:
filename = media_path.rsplit("/", 1)[-1] filename = media_path.rsplit("/", 1)[-1]
@@ -318,7 +326,8 @@ class TelegramChannel(BaseChannel):
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=f"[Failed to send: {filename}]", text=f"[Failed to send: {filename}]",
reply_parameters=reply_params reply_parameters=reply_params,
**thread_kwargs,
) )
# Send text content # Send text content
@@ -328,28 +337,44 @@ class TelegramChannel(BaseChannel):
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN): for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
# Final response: simulate streaming via draft, then persist # Final response: simulate streaming via draft, then persist
if not is_progress: if not is_progress:
await self._send_with_streaming(chat_id, chunk, reply_params) await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
else: else:
await self._send_text(chat_id, chunk, reply_params) await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None: async def _send_text(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Send a plain text message with HTML fallback.""" """Send a plain text message with HTML fallback."""
try: try:
html = _markdown_to_telegram_html(text) html = _markdown_to_telegram_html(text)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, text=html, parse_mode="HTML", chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params, reply_parameters=reply_params,
**(thread_kwargs or {}),
) )
except Exception as e: except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, text=text, reply_parameters=reply_params, chat_id=chat_id,
text=text,
reply_parameters=reply_params,
**(thread_kwargs or {}),
) )
except Exception as e2: except Exception as e2:
logger.error("Error sending Telegram message: {}", e2) logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None: async def _send_with_streaming(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message.""" """Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31) draft_id = int(time.time() * 1000) % (2**31)
try: try:
@@ -365,7 +390,7 @@ class TelegramChannel(BaseChannel):
await asyncio.sleep(0.15) await asyncio.sleep(0.15)
except Exception: except Exception:
pass pass
await self._send_text(chat_id, text, reply_params) await self._send_text(chat_id, text, reply_params, thread_kwargs)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
@@ -396,14 +421,50 @@ class TelegramChannel(BaseChannel):
sid = str(user.id) sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid return f"{sid}|{user.username}" if user.username else sid
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
@staticmethod
def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload."""
return {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)),
}
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
key = (str(message.chat_id), message.message_id)
self._message_threads[key] = message_thread_id
if len(self._message_threads) > 1000:
self._message_threads.pop(next(iter(self._message_threads)))
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop.""" """Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message
user = update.effective_user
self._remember_thread_context(message)
await self._handle_message( await self._handle_message(
sender_id=self._sender_id(update.effective_user), sender_id=self._sender_id(user),
chat_id=str(update.message.chat_id), chat_id=str(message.chat_id),
content=update.message.text, content=message.text,
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -415,6 +476,7 @@ class TelegramChannel(BaseChannel):
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user) sender_id = self._sender_id(user)
self._remember_thread_context(message)
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
@@ -485,6 +547,8 @@ class TelegramChannel(BaseChannel):
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
metadata = self._build_message_metadata(message, user)
session_key = self._derive_topic_session_key(message)
# Telegram media groups: buffer briefly, forward as one aggregated turn. # Telegram media groups: buffer briefly, forward as one aggregated turn.
if media_group_id := getattr(message, "media_group_id", None): if media_group_id := getattr(message, "media_group_id", None):
@@ -493,11 +557,8 @@ class TelegramChannel(BaseChannel):
self._media_group_buffers[key] = { self._media_group_buffers[key] = {
"sender_id": sender_id, "chat_id": str_chat_id, "sender_id": sender_id, "chat_id": str_chat_id,
"contents": [], "media": [], "contents": [], "media": [],
"metadata": { "metadata": metadata,
"message_id": message.message_id, "user_id": user.id, "session_key": session_key,
"username": user.username, "first_name": user.first_name,
"is_group": message.chat.type != "private",
},
} }
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
buf = self._media_group_buffers[key] buf = self._media_group_buffers[key]
@@ -517,13 +578,8 @@ class TelegramChannel(BaseChannel):
chat_id=str_chat_id, chat_id=str_chat_id,
content=content, content=content,
media=media_paths, media=media_paths,
metadata={ metadata=metadata,
"message_id": message.message_id, session_key=session_key,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private"
}
) )
async def _flush_media_group(self, key: str) -> None: async def _flush_media_group(self, key: str) -> None:
@@ -537,6 +593,7 @@ class TelegramChannel(BaseChannel):
sender_id=buf["sender_id"], chat_id=buf["chat_id"], sender_id=buf["sender_id"], chat_id=buf["chat_id"],
content=content, media=list(dict.fromkeys(buf["media"])), content=content, media=list(dict.fromkeys(buf["media"])),
metadata=buf["metadata"], metadata=buf["metadata"],
session_key=buf.get("session_key"),
) )
finally: finally:
self._media_group_tasks.pop(key, None) self._media_group_tasks.pop(key, None)

View File

@@ -86,7 +86,6 @@ class TestMessageToolSuppressLogic:
assert result is not None assert result is not None
assert "Hello" in result.content assert "Hello" in result.content
@pytest.mark.asyncio
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})

View File

@@ -2,6 +2,7 @@ from types import SimpleNamespace
import pytest import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig from nanobot.config.schema import TelegramConfig
@@ -24,12 +25,18 @@ class _FakeUpdater:
class _FakeBot: class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
async def get_me(self): async def get_me(self):
return SimpleNamespace(username="nanobot_test") return SimpleNamespace(username="nanobot_test")
async def set_my_commands(self, commands) -> None: async def set_my_commands(self, commands) -> None:
self.commands = commands self.commands = commands
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
class _FakeApp: class _FakeApp:
def __init__(self, on_start_polling) -> None: def __init__(self, on_start_polling) -> None:
@@ -105,3 +112,51 @@ async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> No
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
assert builder.request_value is _FakeHTTPXRequest.instances[0] assert builder.request_value is _FakeHTTPXRequest.instances[0]
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0] assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
chat_id=-100123,
message_thread_id=42,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"_progress": True, "message_thread_id": 42},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
channel._message_threads[("123", 10)] = 42
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"message_id": 10},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10