fix(telegram): aggregate media-group images into a single inbound turn

This commit is contained in:
Kim
2026-02-27 12:08:48 +08:00
parent cab901b2fb
commit aa774733ea

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
from loguru import logger from loguru import logger
from telegram import BotCommand, Update, ReplyParameters from telegram import BotCommand, ReplyParameters, Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
@@ -21,60 +22,60 @@ def _markdown_to_telegram_html(text: str) -> str:
""" """
if not text: if not text:
return "" return ""
# 1. Extract and protect code blocks (preserve content from other processing) # 1. Extract and protect code blocks (preserve content from other processing)
code_blocks: list[str] = [] code_blocks: list[str] = []
def save_code_block(m: re.Match) -> str: def save_code_block(m: re.Match) -> str:
code_blocks.append(m.group(1)) code_blocks.append(m.group(1))
return f"\x00CB{len(code_blocks) - 1}\x00" return f"\x00CB{len(code_blocks) - 1}\x00"
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
inline_codes.append(m.group(1)) inline_codes.append(m.group(1))
return f"\x00IC{len(inline_codes) - 1}\x00" return f"\x00IC{len(inline_codes) - 1}\x00"
text = re.sub(r'`([^`]+)`', save_inline_code, text) text = re.sub(r'`([^`]+)`', save_inline_code, text)
# 3. Headers # Title -> just the title text # 3. Headers # Title -> just the title text
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
# 4. Blockquotes > text -> just the text (before HTML escaping) # 4. Blockquotes > text -> just the text (before HTML escaping)
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
# 5. Escape HTML special characters # 5. Escape HTML special characters
text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
# 6. Links [text](url) - must be before bold/italic to handle nested cases # 6. Links [text](url) - must be before bold/italic to handle nested cases
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text) text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
# 7. Bold **text** or __text__ # 7. Bold **text** or __text__
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text) text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text) text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
# 8. Italic _text_ (avoid matching inside words like some_var_name) # 8. Italic _text_ (avoid matching inside words like some_var_name)
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text) text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
# 9. Strikethrough ~~text~~ # 9. Strikethrough ~~text~~
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text) text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
# 10. Bullet lists - item -> • item # 10. Bullet lists - item -> • item
text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE) text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE)
# 11. Restore inline code with HTML tags # 11. Restore inline code with HTML tags
for i, code in enumerate(inline_codes): for i, code in enumerate(inline_codes):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>") text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
# 12. Restore code blocks with HTML tags # 12. Restore code blocks with HTML tags
for i, code in enumerate(code_blocks): for i, code in enumerate(code_blocks):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>") text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
return text return text
@@ -101,12 +102,12 @@ def _split_message(content: str, max_len: int = 4000) -> list[str]:
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
""" """
Telegram channel using long polling. Telegram channel using long polling.
Simple and reliable - no webhook/public IP needed. Simple and reliable - no webhook/public IP needed.
""" """
name = "telegram" name = "telegram"
# Commands registered with Telegram's command menu # Commands registered with Telegram's command menu
BOT_COMMANDS = [ BOT_COMMANDS = [
BotCommand("start", "Start the bot"), BotCommand("start", "Start the bot"),
@@ -114,7 +115,7 @@ class TelegramChannel(BaseChannel):
BotCommand("stop", "Stop the current task"), BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
def __init__( def __init__(
self, self,
config: TelegramConfig, config: TelegramConfig,
@@ -127,15 +128,17 @@ class TelegramChannel(BaseChannel):
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict[str, object]] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
if not self.config.token: if not self.config.token:
logger.error("Telegram bot token not configured") logger.error("Telegram bot token not configured")
return return
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs # Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
@@ -143,62 +146,69 @@ class TelegramChannel(BaseChannel):
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy) builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
# Add command handlers # Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help)) self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
self._app.add_handler( self._app.add_handler(
MessageHandler( MessageHandler(
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL) (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
& ~filters.COMMAND, & ~filters.COMMAND,
self._on_message self._on_message
) )
) )
logger.info("Starting Telegram bot (polling mode)...") logger.info("Starting Telegram bot (polling mode)...")
# Initialize and start polling # Initialize and start polling
await self._app.initialize() await self._app.initialize()
await self._app.start() await self._app.start()
# Get bot info and register command menu # Get bot info and register command menu
bot_info = await self._app.bot.get_me() bot_info = await self._app.bot.get_me()
logger.info("Telegram bot @{} connected", bot_info.username) logger.info("Telegram bot @{} connected", bot_info.username)
try: try:
await self._app.bot.set_my_commands(self.BOT_COMMANDS) await self._app.bot.set_my_commands(self.BOT_COMMANDS)
logger.debug("Telegram bot commands registered") logger.debug("Telegram bot commands registered")
except Exception as e: except Exception as e:
logger.warning("Failed to register bot commands: {}", e) logger.warning("Failed to register bot commands: {}", e)
# Start polling (this runs until stopped) # Start polling (this runs until stopped)
await self._app.updater.start_polling( await self._app.updater.start_polling(
allowed_updates=["message"], allowed_updates=["message"],
drop_pending_updates=True # Ignore old messages on startup drop_pending_updates=True # Ignore old messages on startup
) )
# Keep running until stopped # Keep running until stopped
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None:
    """Shut the channel down.

    Clears the running flag, cancels every typing-indicator task and every
    pending media-group flush task, discards buffered media-group parts,
    and finally stops and shuts down the Telegram application if one exists.
    """
    self._running = False

    # Typing loops are cancelled one chat at a time; iterate a snapshot
    # because _stop_typing removes entries from the dict.
    for chat_id in list(self._typing_tasks):
        self._stop_typing(chat_id)

    # Drop all pending media-group flushes together with whatever they
    # had buffered so far.
    pending_flushes = list(self._media_group_tasks.values())
    self._media_group_tasks.clear()
    for flush_task in pending_flushes:
        if not flush_task or flush_task.done():
            continue
        flush_task.cancel()
    self._media_group_buffers.clear()

    if not self._app:
        return

    logger.info("Stopping Telegram bot...")
    await self._app.updater.stop()
    await self._app.stop()
    await self._app.shutdown()
    self._app = None
@staticmethod @staticmethod
def _get_media_type(path: str) -> str: def _get_media_type(path: str) -> str:
"""Guess media type from file extension.""" """Guess media type from file extension."""
@@ -246,7 +256,7 @@ class TelegramChannel(BaseChannel):
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document" param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
with open(media_path, 'rb') as f: with open(media_path, 'rb') as f:
await sender( await sender(
chat_id=chat_id, chat_id=chat_id,
**{param: f}, **{param: f},
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -265,8 +275,8 @@ class TelegramChannel(BaseChannel):
try: try:
html = _markdown_to_telegram_html(chunk) html = _markdown_to_telegram_html(chunk)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=html, text=html,
parse_mode="HTML", parse_mode="HTML",
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -274,13 +284,13 @@ class TelegramChannel(BaseChannel):
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=chunk, text=chunk,
reply_parameters=reply_params reply_parameters=reply_params
) )
except Exception as e2: except Exception as e2:
logger.error("Error sending Telegram message: {}", e2) logger.error("Error sending Telegram message: {}", e2)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
@@ -319,34 +329,34 @@ class TelegramChannel(BaseChannel):
chat_id=str(update.message.chat_id), chat_id=str(update.message.chat_id),
content=update.message.text, content=update.message.text,
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming messages (text, photos, voice, documents).""" """Handle incoming messages (text, photos, voice, documents)."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message message = update.message
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user) sender_id = self._sender_id(user)
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
# Build content from text and/or media # Build content from text and/or media
content_parts = [] content_parts = []
media_paths = [] media_paths = []
# Text content # Text content
if message.text: if message.text:
content_parts.append(message.text) content_parts.append(message.text)
if message.caption: if message.caption:
content_parts.append(message.caption) content_parts.append(message.caption)
# Handle media files # Handle media files
media_file = None media_file = None
media_type = None media_type = None
if message.photo: if message.photo:
media_file = message.photo[-1] # Largest photo media_file = message.photo[-1] # Largest photo
media_type = "image" media_type = "image"
@@ -359,23 +369,23 @@ class TelegramChannel(BaseChannel):
elif message.document: elif message.document:
media_file = message.document media_file = message.document
media_type = "file" media_type = "file"
# Download media if present # Download media if present
if media_file and self._app: if media_file and self._app:
try: try:
file = await self._app.bot.get_file(media_file.file_id) file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None)) ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
# Save to workspace/media/ # Save to workspace/media/
from pathlib import Path from pathlib import Path
media_dir = Path.home() / ".nanobot" / "media" media_dir = Path.home() / ".nanobot" / "media"
media_dir.mkdir(parents=True, exist_ok=True) media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{media_file.file_id[:16]}{ext}" file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path)) await file.download_to_drive(str(file_path))
media_paths.append(str(file_path)) media_paths.append(str(file_path))
# Handle voice transcription # Handle voice transcription
if media_type == "voice" or media_type == "audio": if media_type == "voice" or media_type == "audio":
from nanobot.providers.transcription import GroqTranscriptionProvider from nanobot.providers.transcription import GroqTranscriptionProvider
@@ -388,21 +398,60 @@ class TelegramChannel(BaseChannel):
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
else: else:
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
logger.debug("Downloaded {} to {}", media_type, file_path) logger.debug("Downloaded {} to {}", media_type, file_path)
except Exception as e: except Exception as e:
logger.error("Failed to download media: {}", e) logger.error("Failed to download media: {}", e)
content_parts.append(f"[{media_type}: download failed]") content_parts.append(f"[{media_type}: download failed]")
content = "\n".join(content_parts) if content_parts else "[empty message]" content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
# Telegram media groups arrive as multiple messages sharing media_group_id.
# Buffer briefly and forward as one aggregated turn.
media_group_id = getattr(message, "media_group_id", None)
if media_group_id:
group_key = f"{str_chat_id}:{media_group_id}"
buffer = self._media_group_buffers.get(group_key)
if not buffer:
buffer = {
"sender_id": sender_id,
"chat_id": str_chat_id,
"contents": [],
"media": [],
"metadata": {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"media_group_id": media_group_id,
},
}
self._media_group_buffers[group_key] = buffer
self._start_typing(str_chat_id)
if content and content != "[empty message]":
cast_contents = buffer["contents"]
if isinstance(cast_contents, list):
cast_contents.append(content)
cast_media = buffer["media"]
if isinstance(cast_media, list):
cast_media.extend(media_paths)
# Start one delayed flush task per media group.
if group_key not in self._media_group_tasks:
self._media_group_tasks[group_key] = asyncio.create_task(
self._flush_media_group(group_key)
)
return
# Start typing indicator before processing # Start typing indicator before processing
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
# Forward to the message bus # Forward to the message bus
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
@@ -417,19 +466,56 @@ class TelegramChannel(BaseChannel):
"is_group": message.chat.type != "private" "is_group": message.chat.type != "private"
} }
) )
async def _flush_media_group(self, group_key: str, delay_s: float = 0.6) -> None:
    """Flush buffered Telegram media-group messages as one aggregated turn.

    Sleeps for ``delay_s`` seconds so that further parts of the same media
    group can be appended to the buffer, then removes the buffer stored
    under ``group_key`` and forwards its combined text and de-duplicated
    media paths through ``_handle_message``.

    This coroutine is run as a detached task, so nothing awaits its result.
    A failure while forwarding is therefore logged here instead of being
    left as an unretrieved task exception, and the typing indicator for the
    chat is stopped.

    Args:
        group_key: Key of the buffer in ``_media_group_buffers``.
        delay_s: Seconds to wait before flushing.
    """
    try:
        await asyncio.sleep(delay_s)
        buffer = self._media_group_buffers.pop(group_key, None)
        if not buffer:
            return

        sender_id = str(buffer.get("sender_id", ""))
        chat_id = str(buffer.get("chat_id", ""))
        contents = buffer.get("contents")
        media = buffer.get("media")
        metadata = buffer.get("metadata")

        # Buffer values are typed as ``object``; narrow them before use.
        content_parts = [
            c
            for c in (contents if isinstance(contents, list) else [])
            if isinstance(c, str) and c
        ]
        media_paths = [
            m
            for m in (media if isinstance(media, list) else [])
            if isinstance(m, str) and m
        ]

        # De-duplicate while preserving order (dicts keep insertion order).
        unique_media = list(dict.fromkeys(media_paths))

        content = "\n".join(content_parts) if content_parts else "[empty message]"
        try:
            await self._handle_message(
                sender_id=sender_id,
                chat_id=chat_id,
                content=content,
                media=unique_media,
                metadata=metadata if isinstance(metadata, dict) else {},
            )
        except Exception as e:
            logger.error("Failed to forward Telegram media group {}: {}", group_key, e)
            # The typing loop runs until cancelled; stop it since this
            # turn was not forwarded.
            self._stop_typing(chat_id)
    finally:
        # Always unregister, including on cancellation, so a later media
        # group with the same key can schedule a new flush.
        self._media_group_tasks.pop(group_key, None)
def _start_typing(self, chat_id: str) -> None: def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat.""" """Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat # Cancel any existing typing task for this chat
self._stop_typing(chat_id) self._stop_typing(chat_id)
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
def _stop_typing(self, chat_id: str) -> None: def _stop_typing(self, chat_id: str) -> None:
"""Stop the typing indicator for a chat.""" """Stop the typing indicator for a chat."""
task = self._typing_tasks.pop(chat_id, None) task = self._typing_tasks.pop(chat_id, None)
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
async def _typing_loop(self, chat_id: str) -> None: async def _typing_loop(self, chat_id: str) -> None:
"""Repeatedly send 'typing' action until cancelled.""" """Repeatedly send 'typing' action until cancelled."""
try: try:
@@ -440,7 +526,7 @@ class TelegramChannel(BaseChannel):
pass pass
except Exception as e: except Exception as e:
logger.debug("Typing indicator stopped for {}: {}", chat_id, e) logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Log polling / handler errors instead of silently swallowing them.""" """Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error) logger.error("Telegram error: {}", context.error)
@@ -454,6 +540,6 @@ class TelegramChannel(BaseChannel):
} }
if mime_type in ext_map: if mime_type in ext_map:
return ext_map[mime_type] return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""} type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "") return type_map.get(media_type, "")