Merge pull request #1482 from contributors/telegram-topic-support
This commit is contained in:
@@ -96,7 +96,7 @@ class MessageTool(Tool):
|
|||||||
media=media or [],
|
media=media or [],
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
self._media_group_buffers: dict[str, dict] = {}
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._message_threads: dict[tuple[str, int], int] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
@@ -286,10 +287,16 @@ class TelegramChannel(BaseChannel):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
reply_to_message_id = msg.metadata.get("message_id")
|
||||||
|
message_thread_id = msg.metadata.get("message_thread_id")
|
||||||
|
if message_thread_id is None and reply_to_message_id is not None:
|
||||||
|
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||||
|
thread_kwargs = {}
|
||||||
|
if message_thread_id is not None:
|
||||||
|
thread_kwargs["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
reply_params = None
|
reply_params = None
|
||||||
if self.config.reply_to_message:
|
if self.config.reply_to_message:
|
||||||
reply_to_message_id = msg.metadata.get("message_id")
|
|
||||||
if reply_to_message_id:
|
if reply_to_message_id:
|
||||||
reply_params = ReplyParameters(
|
reply_params = ReplyParameters(
|
||||||
message_id=reply_to_message_id,
|
message_id=reply_to_message_id,
|
||||||
@@ -310,7 +317,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
await sender(
|
await sender(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
**{param: f},
|
**{param: f},
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params,
|
||||||
|
**thread_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
filename = media_path.rsplit("/", 1)[-1]
|
filename = media_path.rsplit("/", 1)[-1]
|
||||||
@@ -318,7 +326,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=f"[Failed to send: {filename}]",
|
text=f"[Failed to send: {filename}]",
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params,
|
||||||
|
**thread_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send text content
|
# Send text content
|
||||||
@@ -328,28 +337,44 @@ class TelegramChannel(BaseChannel):
|
|||||||
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||||
# Final response: simulate streaming via draft, then persist
|
# Final response: simulate streaming via draft, then persist
|
||||||
if not is_progress:
|
if not is_progress:
|
||||||
await self._send_with_streaming(chat_id, chunk, reply_params)
|
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||||
else:
|
else:
|
||||||
await self._send_text(chat_id, chunk, reply_params)
|
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||||
|
|
||||||
async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None:
|
async def _send_text(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_params=None,
|
||||||
|
thread_kwargs: dict | None = None,
|
||||||
|
) -> None:
|
||||||
"""Send a plain text message with HTML fallback."""
|
"""Send a plain text message with HTML fallback."""
|
||||||
try:
|
try:
|
||||||
html = _markdown_to_telegram_html(text)
|
html = _markdown_to_telegram_html(text)
|
||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id, text=html, parse_mode="HTML",
|
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||||
reply_parameters=reply_params,
|
reply_parameters=reply_params,
|
||||||
|
**(thread_kwargs or {}),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
try:
|
try:
|
||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id, text=text, reply_parameters=reply_params,
|
chat_id=chat_id,
|
||||||
|
text=text,
|
||||||
|
reply_parameters=reply_params,
|
||||||
|
**(thread_kwargs or {}),
|
||||||
)
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
logger.error("Error sending Telegram message: {}", e2)
|
logger.error("Error sending Telegram message: {}", e2)
|
||||||
|
|
||||||
async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None:
|
async def _send_with_streaming(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_params=None,
|
||||||
|
thread_kwargs: dict | None = None,
|
||||||
|
) -> None:
|
||||||
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||||
draft_id = int(time.time() * 1000) % (2**31)
|
draft_id = int(time.time() * 1000) % (2**31)
|
||||||
try:
|
try:
|
||||||
@@ -365,7 +390,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
await asyncio.sleep(0.15)
|
await asyncio.sleep(0.15)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
await self._send_text(chat_id, text, reply_params)
|
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||||
|
|
||||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
@@ -396,14 +421,50 @@ class TelegramChannel(BaseChannel):
|
|||||||
sid = str(user.id)
|
sid = str(user.id)
|
||||||
return f"{sid}|{user.username}" if user.username else sid
|
return f"{sid}|{user.username}" if user.username else sid
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _derive_topic_session_key(message) -> str | None:
|
||||||
|
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||||
|
message_thread_id = getattr(message, "message_thread_id", None)
|
||||||
|
if message.chat.type == "private" or message_thread_id is None:
|
||||||
|
return None
|
||||||
|
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_message_metadata(message, user) -> dict:
|
||||||
|
"""Build common Telegram inbound metadata payload."""
|
||||||
|
return {
|
||||||
|
"message_id": message.message_id,
|
||||||
|
"user_id": user.id,
|
||||||
|
"username": user.username,
|
||||||
|
"first_name": user.first_name,
|
||||||
|
"is_group": message.chat.type != "private",
|
||||||
|
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||||
|
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _remember_thread_context(self, message) -> None:
|
||||||
|
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||||
|
message_thread_id = getattr(message, "message_thread_id", None)
|
||||||
|
if message_thread_id is None:
|
||||||
|
return
|
||||||
|
key = (str(message.chat_id), message.message_id)
|
||||||
|
self._message_threads[key] = message_thread_id
|
||||||
|
if len(self._message_threads) > 1000:
|
||||||
|
self._message_threads.pop(next(iter(self._message_threads)))
|
||||||
|
|
||||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
|
message = update.message
|
||||||
|
user = update.effective_user
|
||||||
|
self._remember_thread_context(message)
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=self._sender_id(update.effective_user),
|
sender_id=self._sender_id(user),
|
||||||
chat_id=str(update.message.chat_id),
|
chat_id=str(message.chat_id),
|
||||||
content=update.message.text,
|
content=message.text,
|
||||||
|
metadata=self._build_message_metadata(message, user),
|
||||||
|
session_key=self._derive_topic_session_key(message),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
@@ -415,6 +476,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
chat_id = message.chat_id
|
chat_id = message.chat_id
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
# Store chat_id for replies
|
# Store chat_id for replies
|
||||||
self._chat_ids[sender_id] = chat_id
|
self._chat_ids[sender_id] = chat_id
|
||||||
@@ -485,6 +547,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
metadata = self._build_message_metadata(message, user)
|
||||||
|
session_key = self._derive_topic_session_key(message)
|
||||||
|
|
||||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||||
if media_group_id := getattr(message, "media_group_id", None):
|
if media_group_id := getattr(message, "media_group_id", None):
|
||||||
@@ -493,11 +557,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._media_group_buffers[key] = {
|
self._media_group_buffers[key] = {
|
||||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||||
"contents": [], "media": [],
|
"contents": [], "media": [],
|
||||||
"metadata": {
|
"metadata": metadata,
|
||||||
"message_id": message.message_id, "user_id": user.id,
|
"session_key": session_key,
|
||||||
"username": user.username, "first_name": user.first_name,
|
|
||||||
"is_group": message.chat.type != "private",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
self._start_typing(str_chat_id)
|
self._start_typing(str_chat_id)
|
||||||
buf = self._media_group_buffers[key]
|
buf = self._media_group_buffers[key]
|
||||||
@@ -517,13 +578,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id=str_chat_id,
|
chat_id=str_chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
media=media_paths,
|
media=media_paths,
|
||||||
metadata={
|
metadata=metadata,
|
||||||
"message_id": message.message_id,
|
session_key=session_key,
|
||||||
"user_id": user.id,
|
|
||||||
"username": user.username,
|
|
||||||
"first_name": user.first_name,
|
|
||||||
"is_group": message.chat.type != "private"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _flush_media_group(self, key: str) -> None:
|
async def _flush_media_group(self, key: str) -> None:
|
||||||
@@ -537,6 +593,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||||
metadata=buf["metadata"],
|
metadata=buf["metadata"],
|
||||||
|
session_key=buf.get("session_key"),
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._media_group_tasks.pop(key, None)
|
self._media_group_tasks.pop(key, None)
|
||||||
|
|||||||
@@ -86,7 +86,6 @@ class TestMessageToolSuppressLogic:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert "Hello" in result.content
|
assert "Hello" in result.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.telegram import TelegramChannel
|
from nanobot.channels.telegram import TelegramChannel
|
||||||
from nanobot.config.schema import TelegramConfig
|
from nanobot.config.schema import TelegramConfig
|
||||||
@@ -24,12 +25,18 @@ class _FakeUpdater:
|
|||||||
|
|
||||||
|
|
||||||
class _FakeBot:
|
class _FakeBot:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent_messages: list[dict] = []
|
||||||
|
|
||||||
async def get_me(self):
|
async def get_me(self):
|
||||||
return SimpleNamespace(username="nanobot_test")
|
return SimpleNamespace(username="nanobot_test")
|
||||||
|
|
||||||
async def set_my_commands(self, commands) -> None:
|
async def set_my_commands(self, commands) -> None:
|
||||||
self.commands = commands
|
self.commands = commands
|
||||||
|
|
||||||
|
async def send_message(self, **kwargs) -> None:
|
||||||
|
self.sent_messages.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
class _FakeApp:
|
class _FakeApp:
|
||||||
def __init__(self, on_start_polling) -> None:
|
def __init__(self, on_start_polling) -> None:
|
||||||
@@ -105,3 +112,51 @@ async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> No
|
|||||||
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
||||||
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
||||||
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||||
|
message = SimpleNamespace(
|
||||||
|
chat=SimpleNamespace(type="supergroup"),
|
||||||
|
chat_id=-100123,
|
||||||
|
message_thread_id=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||||
|
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||||
|
channel = TelegramChannel(config, MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"_progress": True, "message_thread_id": 42},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||||
|
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||||
|
channel = TelegramChannel(config, MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
channel._message_threads[("123", 10)] = 42
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": 10},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||||
|
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||||
|
|||||||
Reference in New Issue
Block a user