fix(telegram): aggregate media-group images into a single inbound turn

This commit is contained in:
Kim
2026-02-27 12:08:48 +08:00
parent cab901b2fb
commit aa774733ea

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
from loguru import logger from loguru import logger
from telegram import BotCommand, Update, ReplyParameters from telegram import BotCommand, ReplyParameters, Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
@@ -127,6 +128,8 @@ class TelegramChannel(BaseChannel):
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict[str, object]] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
@@ -192,6 +195,13 @@ class TelegramChannel(BaseChannel):
for chat_id in list(self._typing_tasks): for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id) self._stop_typing(chat_id)
# Cancel buffered media-group flush tasks
for key, task in list(self._media_group_tasks.items()):
if task and not task.done():
task.cancel()
self._media_group_tasks.pop(key, None)
self._media_group_buffers.clear()
if self._app: if self._app:
logger.info("Stopping Telegram bot...") logger.info("Stopping Telegram bot...")
await self._app.updater.stop() await self._app.updater.stop()
@@ -400,6 +410,45 @@ class TelegramChannel(BaseChannel):
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
# Telegram media groups arrive as multiple messages sharing media_group_id.
# Buffer briefly and forward as one aggregated turn.
media_group_id = getattr(message, "media_group_id", None)
if media_group_id:
group_key = f"{str_chat_id}:{media_group_id}"
buffer = self._media_group_buffers.get(group_key)
if not buffer:
buffer = {
"sender_id": sender_id,
"chat_id": str_chat_id,
"contents": [],
"media": [],
"metadata": {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"media_group_id": media_group_id,
},
}
self._media_group_buffers[group_key] = buffer
self._start_typing(str_chat_id)
if content and content != "[empty message]":
cast_contents = buffer["contents"]
if isinstance(cast_contents, list):
cast_contents.append(content)
cast_media = buffer["media"]
if isinstance(cast_media, list):
cast_media.extend(media_paths)
# Start one delayed flush task per media group.
if group_key not in self._media_group_tasks:
self._media_group_tasks[group_key] = asyncio.create_task(
self._flush_media_group(group_key)
)
return
# Start typing indicator before processing # Start typing indicator before processing
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
@@ -418,6 +467,43 @@ class TelegramChannel(BaseChannel):
} }
) )
async def _flush_media_group(self, group_key: str, delay_s: float = 0.6) -> None:
"""Flush buffered Telegram media-group messages as one aggregated turn."""
try:
await asyncio.sleep(delay_s)
buffer = self._media_group_buffers.pop(group_key, None)
if not buffer:
return
sender_id = str(buffer.get("sender_id", ""))
chat_id = str(buffer.get("chat_id", ""))
contents = buffer.get("contents")
media = buffer.get("media")
metadata = buffer.get("metadata")
content_parts = [c for c in (contents if isinstance(contents, list) else []) if isinstance(c, str) and c]
media_paths = [m for m in (media if isinstance(media, list) else []) if isinstance(m, str) and m]
# De-duplicate while preserving order
seen = set()
unique_media: list[str] = []
for m in media_paths:
if m in seen:
continue
seen.add(m)
unique_media.append(m)
content = "\n".join(content_parts) if content_parts else "[empty message]"
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=content,
media=unique_media,
metadata=metadata if isinstance(metadata, dict) else {},
)
finally:
self._media_group_tasks.pop(group_key, None)
def _start_typing(self, chat_id: str) -> None: def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat.""" """Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat # Cancel any existing typing task for this chat