Merge branch 'main' into pr-1361
This commit is contained in:
@@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
|
||||
name: str = "base"
|
||||
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
@@ -30,50 +30,50 @@ class BaseChannel(ABC):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""
|
||||
Check if a sender is allowed to use this bot.
|
||||
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise.
|
||||
"""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
|
||||
|
||||
# If no allow list, allow everyone
|
||||
if not allow_list:
|
||||
return True
|
||||
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str in allow_list:
|
||||
return True
|
||||
@@ -82,7 +82,7 @@ class BaseChannel(ABC):
|
||||
if part and part in allow_list:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
@@ -94,9 +94,9 @@ class BaseChannel(ABC):
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
@@ -112,7 +112,7 @@ class BaseChannel(ABC):
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
@@ -122,9 +122,9 @@ class BaseChannel(ABC):
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
|
||||
@@ -9,8 +9,8 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from loguru import logger
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@@ -19,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
DingTalkStreamClient,
|
||||
Credential,
|
||||
AckMessage,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
AckMessage,
|
||||
Credential,
|
||||
DingTalkStreamClient,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
@@ -23,12 +23,11 @@ try:
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
Emoji,
|
||||
GetFileRequest,
|
||||
GetMessageResourceRequest,
|
||||
P2ImMessageReceiveV1,
|
||||
)
|
||||
@@ -70,7 +69,7 @@ def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||
def _extract_interactive_content(content: dict) -> list[str]:
|
||||
"""Recursively extract text and links from interactive card content."""
|
||||
parts = []
|
||||
|
||||
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
@@ -104,19 +103,19 @@ def _extract_interactive_content(content: dict) -> list[str]:
|
||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||
if header_text:
|
||||
parts.append(f"title: {header_text}")
|
||||
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_element_content(element: dict) -> list[str]:
|
||||
"""Extract content from a single card element."""
|
||||
parts = []
|
||||
|
||||
|
||||
if not isinstance(element, dict):
|
||||
return parts
|
||||
|
||||
|
||||
tag = element.get("tag", "")
|
||||
|
||||
|
||||
if tag in ("markdown", "lark_md"):
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
@@ -177,17 +176,17 @@ def _extract_element_content(element: dict) -> list[str]:
|
||||
else:
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
"""Extract text and image keys from Feishu post (rich text) message content.
|
||||
|
||||
|
||||
Supports two formats:
|
||||
1. Direct format: {"title": "...", "content": [...]}
|
||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
|
||||
|
||||
Returns:
|
||||
(text, image_keys) - extracted text and list of image keys
|
||||
"""
|
||||
@@ -220,7 +219,7 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
image_keys.append(img_key)
|
||||
text = " ".join(text_parts).strip() if text_parts else None
|
||||
return text, image_keys
|
||||
|
||||
|
||||
# Compatible with both shapes:
|
||||
# 1) {"post": {"zh_cn": {...}}}
|
||||
# 2) {"zh_cn": {...}} or {"title": "...", "content": [...]}
|
||||
@@ -233,7 +232,7 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
text, images = extract_from_lang(post_root)
|
||||
if text or images:
|
||||
return text or "", images
|
||||
|
||||
|
||||
# Try localized format
|
||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||
lang_content = post_root.get(lang_key)
|
||||
@@ -247,13 +246,13 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
text, images = extract_from_lang(value)
|
||||
if text or images:
|
||||
return text or "", images
|
||||
|
||||
|
||||
return "", []
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
|
||||
Legacy wrapper for _extract_post_content, returns only text.
|
||||
"""
|
||||
text, _ = _extract_post_content(content_json)
|
||||
@@ -263,17 +262,17 @@ def _extract_post_text(content_json: dict) -> str:
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
|
||||
Requires:
|
||||
- App ID and App Secret from Feishu Open Platform
|
||||
- Bot capability enabled
|
||||
- Event subscription enabled (im.message.receive_v1)
|
||||
"""
|
||||
|
||||
|
||||
name = "feishu"
|
||||
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig = config
|
||||
@@ -282,27 +281,27 @@ class FeishuChannel(BaseChannel):
|
||||
self._ws_thread: threading.Thread | None = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||
return
|
||||
|
||||
|
||||
if not self.config.app_id or not self.config.app_secret:
|
||||
logger.error("Feishu app_id and app_secret not configured")
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
# Create Lark client for sending messages
|
||||
self._client = lark.Client.builder() \
|
||||
.app_id(self.config.app_id) \
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
|
||||
|
||||
# Create event handler (only register message receive, ignore other events)
|
||||
event_handler = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
@@ -310,7 +309,7 @@ class FeishuChannel(BaseChannel):
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
).build()
|
||||
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
@@ -318,7 +317,7 @@ class FeishuChannel(BaseChannel):
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop
|
||||
def run_ws():
|
||||
while self._running:
|
||||
@@ -327,18 +326,19 @@ class FeishuChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.warning("Feishu WebSocket error: {}", e)
|
||||
if self._running:
|
||||
import time; time.sleep(5)
|
||||
|
||||
import time
|
||||
time.sleep(5)
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the Feishu bot.
|
||||
@@ -349,7 +349,7 @@ class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
self._running = False
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
try:
|
||||
@@ -360,9 +360,9 @@ class FeishuChannel(BaseChannel):
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||
.build()
|
||||
).build()
|
||||
|
||||
|
||||
response = self._client.im.v1.message_reaction.create(request)
|
||||
|
||||
|
||||
if not response.success():
|
||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||
else:
|
||||
@@ -373,15 +373,15 @@ class FeishuChannel(BaseChannel):
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
if not self._client or not Emoji:
|
||||
return
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
@@ -395,12 +395,13 @@ class FeishuChannel(BaseChannel):
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()]
|
||||
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
|
||||
if len(lines) < 3:
|
||||
return None
|
||||
split = lambda l: [c.strip() for c in l.strip("|").split("|")]
|
||||
def split(_line: str) -> list[str]:
|
||||
return [c.strip() for c in _line.strip("|").split("|")]
|
||||
headers = split(lines[0])
|
||||
rows = [split(l) for l in lines[2:]]
|
||||
rows = [split(_line) for _line in lines[2:]]
|
||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||
for i, h in enumerate(headers)]
|
||||
return {
|
||||
@@ -672,7 +673,7 @@ class FeishuChannel(BaseChannel):
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
|
||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
@@ -680,7 +681,7 @@ class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
|
||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
|
||||
@@ -16,24 +16,24 @@ from nanobot.config.schema import Config
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
|
||||
|
||||
Responsibilities:
|
||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||
- Start/stop channels
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
self._init_channels()
|
||||
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels based on config."""
|
||||
|
||||
|
||||
# Telegram channel
|
||||
if self.config.channels.telegram.enabled:
|
||||
try:
|
||||
@@ -46,7 +46,7 @@ class ChannelManager:
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
|
||||
# WhatsApp channel
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
@@ -68,7 +68,7 @@ class ChannelManager:
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
|
||||
# Feishu channel
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
@@ -136,7 +136,7 @@ class ChannelManager:
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
|
||||
# Matrix channel
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
@@ -148,7 +148,7 @@ class ChannelManager:
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
try:
|
||||
@@ -161,23 +161,23 @@ class ChannelManager:
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
|
||||
# Start outbound dispatcher
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
|
||||
# Start channels
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
|
||||
# Wait for all to complete (they should run forever)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all channels and the dispatcher."""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
|
||||
# Stop dispatcher
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
@@ -185,7 +185,7 @@ class ChannelManager:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# Stop all channels
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
@@ -193,24 +193,24 @@ class ChannelManager:
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
@@ -219,16 +219,16 @@ class ChannelManager:
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""Get a channel by name."""
|
||||
return self.channels.get(name)
|
||||
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get status of all channels."""
|
||||
return {
|
||||
@@ -238,7 +238,7 @@ class ChannelManager:
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""Get list of enabled channel names."""
|
||||
|
||||
@@ -12,10 +12,22 @@ try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient, AsyncClientConfig, ContentRepositoryConfigError,
|
||||
DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse,
|
||||
RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText,
|
||||
RoomSendError, RoomTypingError, SyncError, UploadError,
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
|
||||
@@ -5,11 +5,10 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, Update, ReplyParameters
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from telegram import BotCommand, ReplyParameters, Update
|
||||
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
@@ -21,60 +22,60 @@ def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -101,12 +102,12 @@ def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
|
||||
name = "telegram"
|
||||
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
@@ -114,7 +115,7 @@ class TelegramChannel(BaseChannel):
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
@@ -129,15 +130,15 @@ class TelegramChannel(BaseChannel):
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
self._media_group_buffers: dict[str, dict] = {}
|
||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
@@ -145,51 +146,51 @@ class TelegramChannel(BaseChannel):
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
@@ -198,14 +199,14 @@ class TelegramChannel(BaseChannel):
|
||||
task.cancel()
|
||||
self._media_group_tasks.clear()
|
||||
self._media_group_buffers.clear()
|
||||
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
@@ -253,7 +254,7 @@ class TelegramChannel(BaseChannel):
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
@@ -272,8 +273,8 @@ class TelegramChannel(BaseChannel):
|
||||
try:
|
||||
html = _markdown_to_telegram_html(chunk)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=html,
|
||||
chat_id=chat_id,
|
||||
text=html,
|
||||
parse_mode="HTML",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
@@ -281,13 +282,13 @@ class TelegramChannel(BaseChannel):
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
chat_id=chat_id,
|
||||
text=chunk,
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
@@ -326,34 +327,34 @@ class TelegramChannel(BaseChannel):
|
||||
chat_id=str(update.message.chat_id),
|
||||
content=update.message.text,
|
||||
)
|
||||
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
@@ -366,23 +367,23 @@ class TelegramChannel(BaseChannel):
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||
|
||||
|
||||
# Save to workspace/media/
|
||||
from pathlib import Path
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
@@ -395,16 +396,16 @@ class TelegramChannel(BaseChannel):
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
|
||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||
@@ -428,10 +429,10 @@ class TelegramChannel(BaseChannel):
|
||||
if key not in self._media_group_tasks:
|
||||
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||
return
|
||||
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
@@ -446,7 +447,7 @@ class TelegramChannel(BaseChannel):
|
||||
"is_group": message.chat.type != "private"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _flush_media_group(self, key: str) -> None:
|
||||
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||
try:
|
||||
@@ -467,13 +468,13 @@ class TelegramChannel(BaseChannel):
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
@@ -484,7 +485,7 @@ class TelegramChannel(BaseChannel):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
@@ -498,6 +499,6 @@ class TelegramChannel(BaseChannel):
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
return type_map.get(media_type, "")
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -29,17 +28,17 @@ class WhatsAppChannel(BaseChannel):
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
|
||||
self._running = True
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
@@ -49,40 +48,40 @@ class WhatsAppChannel(BaseChannel):
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
@@ -92,7 +91,7 @@ class WhatsAppChannel(BaseChannel):
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
@@ -100,9 +99,9 @@ class WhatsAppChannel(BaseChannel):
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
@@ -139,20 +138,20 @@ class WhatsAppChannel(BaseChannel):
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
|
||||
Reference in New Issue
Block a user