diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 35e519a..0a52427 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -96,7 +96,7 @@ class MessageTool(Tool): media=media or [], metadata={ "message_id": message_id, - } + }, ) try: diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 2c92c2d..6c57288 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -177,6 +177,7 @@ class TelegramChannel(BaseChannel): self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._media_group_buffers: dict[str, dict] = {} self._media_group_tasks: dict[str, asyncio.Task] = {} + self._message_threads: dict[tuple[str, int], int] = {} async def start(self) -> None: """Start the Telegram bot with long polling.""" @@ -286,10 +287,16 @@ class TelegramChannel(BaseChannel): except ValueError: logger.error("Invalid chat_id: {}", msg.chat_id) return + reply_to_message_id = msg.metadata.get("message_id") + message_thread_id = msg.metadata.get("message_thread_id") + if message_thread_id is None and reply_to_message_id is not None: + message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id)) + thread_kwargs = {} + if message_thread_id is not None: + thread_kwargs["message_thread_id"] = message_thread_id reply_params = None if self.config.reply_to_message: - reply_to_message_id = msg.metadata.get("message_id") if reply_to_message_id: reply_params = ReplyParameters( message_id=reply_to_message_id, @@ -310,7 +317,8 @@ class TelegramChannel(BaseChannel): await sender( chat_id=chat_id, **{param: f}, - reply_parameters=reply_params + reply_parameters=reply_params, + **thread_kwargs, ) except Exception as e: filename = media_path.rsplit("/", 1)[-1] @@ -318,7 +326,8 @@ class TelegramChannel(BaseChannel): await self._app.bot.send_message( chat_id=chat_id, text=f"[Failed to send: {filename}]", - reply_parameters=reply_params + reply_parameters=reply_params, + **thread_kwargs, ) # Send text content @@ -328,28 +337,44 @@ class TelegramChannel(BaseChannel): for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN): # Final response: simulate streaming via draft, then persist if not is_progress: - await self._send_with_streaming(chat_id, chunk, reply_params) + await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs) else: - await self._send_text(chat_id, chunk, reply_params) + await self._send_text(chat_id, chunk, reply_params, thread_kwargs) - async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None: + async def _send_text( + self, + chat_id: int, + text: str, + reply_params=None, + thread_kwargs: dict | None = None, + ) -> None: """Send a plain text message with HTML fallback.""" try: html = _markdown_to_telegram_html(text) await self._app.bot.send_message( chat_id=chat_id, text=html, parse_mode="HTML", reply_parameters=reply_params, + **(thread_kwargs or {}), ) except Exception as e: logger.warning("HTML parse failed, falling back to plain text: {}", e) try: await self._app.bot.send_message( - chat_id=chat_id, text=text, reply_parameters=reply_params, + chat_id=chat_id, + text=text, + reply_parameters=reply_params, + **(thread_kwargs or {}), ) except Exception as e2: logger.error("Error sending Telegram message: {}", e2) - async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None: + async def _send_with_streaming( + self, + chat_id: int, + text: str, + reply_params=None, + thread_kwargs: dict | None = None, + ) -> None: """Simulate streaming via send_message_draft, then persist with send_message.""" draft_id = int(time.time() * 1000) % (2**31) try: @@ -365,7 +390,7 @@ class TelegramChannel(BaseChannel): await asyncio.sleep(0.15) except Exception: pass - await self._send_text(chat_id, text, reply_params) + await self._send_text(chat_id, text, reply_params, thread_kwargs) async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle /start command.""" @@ -396,14 +421,50 @@ class TelegramChannel(BaseChannel): sid = str(user.id) return f"{sid}|{user.username}" if user.username else sid + @staticmethod + def _derive_topic_session_key(message) -> str | None: + """Derive topic-scoped session key for non-private Telegram chats.""" + message_thread_id = getattr(message, "message_thread_id", None) + if message.chat.type == "private" or message_thread_id is None: + return None + return f"telegram:{message.chat_id}:topic:{message_thread_id}" + + @staticmethod + def _build_message_metadata(message, user) -> dict: + """Build common Telegram inbound metadata payload.""" + return { + "message_id": message.message_id, + "user_id": user.id, + "username": user.username, + "first_name": user.first_name, + "is_group": message.chat.type != "private", + "message_thread_id": getattr(message, "message_thread_id", None), + "is_forum": bool(getattr(message.chat, "is_forum", False)), + } + + def _remember_thread_context(self, message) -> None: + """Cache topic thread id by chat/message id for follow-up replies.""" + message_thread_id = getattr(message, "message_thread_id", None) + if message_thread_id is None: + return + key = (str(message.chat_id), message.message_id) + self._message_threads[key] = message_thread_id + if len(self._message_threads) > 1000: + self._message_threads.pop(next(iter(self._message_threads))) + async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Forward slash commands to the bus for unified handling in AgentLoop.""" if not update.message or not update.effective_user: return + message = update.message + user = update.effective_user + self._remember_thread_context(message) await self._handle_message( - sender_id=self._sender_id(update.effective_user), - chat_id=str(update.message.chat_id), - content=update.message.text, + sender_id=self._sender_id(user), + chat_id=str(message.chat_id), + content=message.text, + metadata=self._build_message_metadata(message, user), + session_key=self._derive_topic_session_key(message), ) async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -415,6 +476,7 @@ class TelegramChannel(BaseChannel): user = update.effective_user chat_id = message.chat_id sender_id = self._sender_id(user) + self._remember_thread_context(message) # Store chat_id for replies self._chat_ids[sender_id] = chat_id @@ -485,6 +547,8 @@ class TelegramChannel(BaseChannel): logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) str_chat_id = str(chat_id) + metadata = self._build_message_metadata(message, user) + session_key = self._derive_topic_session_key(message) # Telegram media groups: buffer briefly, forward as one aggregated turn. if media_group_id := getattr(message, "media_group_id", None): @@ -493,11 +557,8 @@ class TelegramChannel(BaseChannel): self._media_group_buffers[key] = { "sender_id": sender_id, "chat_id": str_chat_id, "contents": [], "media": [], - "metadata": { - "message_id": message.message_id, "user_id": user.id, - "username": user.username, "first_name": user.first_name, - "is_group": message.chat.type != "private", - }, + "metadata": metadata, + "session_key": session_key, } self._start_typing(str_chat_id) buf = self._media_group_buffers[key] @@ -517,13 +578,8 @@ class TelegramChannel(BaseChannel): chat_id=str_chat_id, content=content, media=media_paths, - metadata={ - "message_id": message.message_id, - "user_id": user.id, - "username": user.username, - "first_name": user.first_name, - "is_group": message.chat.type != "private" - } + metadata=metadata, + session_key=session_key, ) async def _flush_media_group(self, key: str) -> None: @@ -537,6 +593,7 @@ class TelegramChannel(BaseChannel): sender_id=buf["sender_id"], chat_id=buf["chat_id"], content=content, media=list(dict.fromkeys(buf["media"])), metadata=buf["metadata"], + session_key=buf.get("session_key"), ) finally: self._media_group_tasks.pop(key, None) diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index f5e65c9..63b0fd1 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -86,7 +86,6 @@ class TestMessageToolSuppressLogic: assert result is not None assert "Hello" in result.content - @pytest.mark.asyncio async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) diff --git a/tests/test_telegram_channel.py b/tests/test_telegram_channel.py index 3bacf96..6100549 100644 --- a/tests/test_telegram_channel.py +++ b/tests/test_telegram_channel.py @@ -2,6 +2,7 @@ from types import SimpleNamespace import pytest +from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.telegram import TelegramChannel from nanobot.config.schema import TelegramConfig @@ -24,12 +25,18 @@ class _FakeUpdater: class _FakeBot: + def __init__(self) -> None: + self.sent_messages: list[dict] = [] + async def get_me(self): return SimpleNamespace(username="nanobot_test") async def set_my_commands(self, commands) -> None: self.commands = commands + async def send_message(self, **kwargs) -> None: + self.sent_messages.append(kwargs) + class _FakeApp: def __init__(self, on_start_polling) -> None: @@ -105,3 +112,51 @@ async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> No assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy assert builder.request_value is _FakeHTTPXRequest.instances[0] assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0] + + +def test_derive_topic_session_key_uses_thread_id() -> None: + message = SimpleNamespace( + chat=SimpleNamespace(type="supergroup"), + chat_id=-100123, + message_thread_id=42, + ) + + assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42" + + +@pytest.mark.asyncio +async def test_send_progress_keeps_message_in_topic() -> None: + config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]) + channel = TelegramChannel(config, MessageBus()) + channel._app = _FakeApp(lambda: None) + + await channel.send( + OutboundMessage( + channel="telegram", + chat_id="123", + content="hello", + metadata={"_progress": True, "message_thread_id": 42}, + ) + ) + + assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42 + + +@pytest.mark.asyncio +async def test_send_reply_infers_topic_from_message_id_cache() -> None: + config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True) + channel = TelegramChannel(config, MessageBus()) + channel._app = _FakeApp(lambda: None) + channel._message_threads[("123", 10)] = 42 + + await channel.send( + OutboundMessage( + channel="telegram", + chat_id="123", + content="hello", + metadata={"message_id": 10}, + ) + ) + + assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42 + assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10