merge main into pr-532 and keep qq msg_seq/startup behavior while adding group @message support with regression tests
This commit is contained in:
@@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
|
||||
name: str = "base"
|
||||
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
@@ -30,97 +30,89 @@ class BaseChannel(ABC):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""
|
||||
Check if a sender is allowed to use this bot.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise.
|
||||
"""
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
|
||||
# If no allow list, allow everyone
|
||||
if not allow_list:
|
||||
logger.warning("{}: allow_from is empty — all access denied", self.name)
|
||||
return False
|
||||
if "*" in allow_list:
|
||||
return True
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str in allow_list:
|
||||
return True
|
||||
if "|" in sender_str:
|
||||
for part in sender_str.split("|"):
|
||||
if part and part in allow_list:
|
||||
return True
|
||||
return False
|
||||
|
||||
return sender_str in allow_list or any(
|
||||
p in allow_list for p in sender_str.split("|") if p
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
f"Access denied for sender {sender_id} on channel {self.name}. "
|
||||
f"Add them to allowFrom list in config to grant access."
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {}
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from loguru import logger
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@@ -15,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
DingTalkStreamClient,
|
||||
Credential,
|
||||
AckMessage,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
AckMessage,
|
||||
Credential,
|
||||
DingTalkStreamClient,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
@@ -58,19 +62,32 @@ class NanobotDingTalkHandler(CallbackHandler):
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
f"Received empty or unsupported message type: {chatbot_msg.message_type}"
|
||||
"Received empty or unsupported message type: {}",
|
||||
chatbot_msg.message_type,
|
||||
)
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
logger.info(f"Received DingTalk message from {sender_name} ({sender_id}): {content}")
|
||||
conversation_type = message.data.get("conversationType")
|
||||
conversation_id = (
|
||||
message.data.get("conversationId")
|
||||
or message.data.get("openConversationId")
|
||||
)
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(content, sender_id, sender_name)
|
||||
self.channel._on_message(
|
||||
content,
|
||||
sender_id,
|
||||
sender_name,
|
||||
conversation_type,
|
||||
conversation_id,
|
||||
)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
@@ -78,7 +95,7 @@ class NanobotDingTalkHandler(CallbackHandler):
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing DingTalk message: {e}")
|
||||
logger.error("Error processing DingTalk message: {}", e)
|
||||
# Return OK to avoid retry loop from DingTalk server
|
||||
return AckMessage.STATUS_OK, "Error"
|
||||
|
||||
@@ -90,11 +107,14 @@ class DingTalkChannel(BaseChannel):
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Note: Currently only supports private (1:1) chat. Group messages are
|
||||
received but replies are sent back as private messages to the sender.
|
||||
Supports both private (1:1) and group chats.
|
||||
Group chat_id is stored with a "group:" prefix to route replies back.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||
|
||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -126,7 +146,8 @@ class DingTalkChannel(BaseChannel):
|
||||
self._http = httpx.AsyncClient()
|
||||
|
||||
logger.info(
|
||||
f"Initializing DingTalk Stream Client with Client ID: {self.config.client_id}..."
|
||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||
self.config.client_id,
|
||||
)
|
||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||
self._client = DingTalkStreamClient(credential)
|
||||
@@ -142,13 +163,13 @@ class DingTalkChannel(BaseChannel):
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning(f"DingTalk stream error: {e}")
|
||||
logger.warning("DingTalk stream error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to start DingTalk channel: {e}")
|
||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DingTalk bot."""
|
||||
@@ -186,60 +207,265 @@ class DingTalkChannel(BaseChannel):
|
||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||
return self._access_token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get DingTalk access token: {e}")
|
||||
logger.error("Failed to get DingTalk access token: {}", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_http_url(value: str) -> bool:
|
||||
return urlparse(value).scheme in ("http", "https")
|
||||
|
||||
def _guess_upload_type(self, media_ref: str) -> str:
|
||||
ext = Path(urlparse(media_ref).path).suffix.lower()
|
||||
if ext in self._IMAGE_EXTS: return "image"
|
||||
if ext in self._AUDIO_EXTS: return "voice"
|
||||
if ext in self._VIDEO_EXTS: return "video"
|
||||
return "file"
|
||||
|
||||
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
|
||||
name = os.path.basename(urlparse(media_ref).path)
|
||||
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
||||
|
||||
async def _read_media_bytes(
|
||||
self,
|
||||
media_ref: str,
|
||||
) -> tuple[bytes | None, str | None, str | None]:
|
||||
if not media_ref:
|
||||
return None, None, None
|
||||
|
||||
if self._is_http_url(media_ref):
|
||||
if not self._http:
|
||||
return None, None, None
|
||||
try:
|
||||
resp = await self._http.get(media_ref, follow_redirects=True)
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"DingTalk media download failed status={} ref={}",
|
||||
resp.status_code,
|
||||
media_ref,
|
||||
)
|
||||
return None, None, None
|
||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
return resp.content, filename, content_type or None
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
try:
|
||||
if media_ref.startswith("file://"):
|
||||
parsed = urlparse(media_ref)
|
||||
local_path = Path(unquote(parsed.path))
|
||||
else:
|
||||
local_path = Path(os.path.expanduser(media_ref))
|
||||
if not local_path.is_file():
|
||||
logger.warning("DingTalk media file not found: {}", local_path)
|
||||
return None, None, None
|
||||
data = await asyncio.to_thread(local_path.read_bytes)
|
||||
content_type = mimetypes.guess_type(local_path.name)[0]
|
||||
return data, local_path.name, content_type
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
async def _upload_media(
|
||||
self,
|
||||
token: str,
|
||||
data: bytes,
|
||||
media_type: str,
|
||||
filename: str,
|
||||
content_type: str | None,
|
||||
) -> str | None:
|
||||
if not self._http:
|
||||
return None
|
||||
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
|
||||
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
files = {"media": (filename, data, mime)}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, files=files)
|
||||
text = resp.text
|
||||
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||
if resp.status_code >= 400:
|
||||
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
||||
return None
|
||||
errcode = result.get("errcode", 0)
|
||||
if errcode != 0:
|
||||
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
||||
return None
|
||||
sub = result.get("result") or {}
|
||||
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
||||
if not media_id:
|
||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||
return None
|
||||
return str(media_id)
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||
return None
|
||||
|
||||
async def _send_batch_message(
|
||||
self,
|
||||
token: str,
|
||||
chat_id: str,
|
||||
msg_key: str,
|
||||
msg_param: dict[str, Any],
|
||||
) -> bool:
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return False
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
if chat_id.startswith("group:"):
|
||||
# Group chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"openConversationId": chat_id[6:], # Remove "group:" prefix,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
else:
|
||||
# Private chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [chat_id],
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=payload, headers=headers)
|
||||
body = resp.text
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||
return False
|
||||
try: result = resp.json()
|
||||
except Exception: result = {}
|
||||
errcode = result.get("errcode")
|
||||
if errcode not in (None, 0):
|
||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||
return False
|
||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||
return False
|
||||
|
||||
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleMarkdown",
|
||||
{"text": content, "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
|
||||
media_ref = (media_ref or "").strip()
|
||||
if not media_ref:
|
||||
return True
|
||||
|
||||
upload_type = self._guess_upload_type(media_ref)
|
||||
if upload_type == "image" and self._is_http_url(media_ref):
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_ref},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
|
||||
|
||||
data, filename, content_type = await self._read_media_bytes(media_ref)
|
||||
if not data:
|
||||
logger.error("DingTalk media read failed: {}", media_ref)
|
||||
return False
|
||||
|
||||
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||
file_type = Path(filename).suffix.lower().lstrip(".")
|
||||
if not file_type:
|
||||
guessed = mimetypes.guess_extension(content_type or "")
|
||||
file_type = (guessed or ".bin").lstrip(".")
|
||||
if file_type == "jpeg":
|
||||
file_type = "jpg"
|
||||
|
||||
media_id = await self._upload_media(
|
||||
token=token,
|
||||
data=data,
|
||||
media_type=upload_type,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
if not media_id:
|
||||
return False
|
||||
|
||||
if upload_type == "image":
|
||||
# Verified in production: sampleImageMsg accepts media_id in photoURL.
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_id},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
|
||||
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleFile",
|
||||
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
|
||||
)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through DingTalk."""
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return
|
||||
|
||||
# oToMessages/batchSend: sends to individual users (private chat)
|
||||
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
||||
if ok:
|
||||
continue
|
||||
logger.error("DingTalk media send failed for {}", media_ref)
|
||||
# Send visible fallback so failures are observable by the user.
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
await self._send_markdown_text(
|
||||
token,
|
||||
msg.chat_id,
|
||||
f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
data = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [msg.chat_id], # chat_id is the user's staffId
|
||||
"msgKey": "sampleMarkdown",
|
||||
"msgParam": json.dumps({
|
||||
"text": msg.content,
|
||||
"title": "Nanobot Reply",
|
||||
}),
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
logger.error(f"DingTalk send failed: {resp.text}")
|
||||
else:
|
||||
logger.debug(f"DingTalk message sent to {msg.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending DingTalk message: {e}")
|
||||
|
||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||
async def _on_message(
|
||||
self,
|
||||
content: str,
|
||||
sender_id: str,
|
||||
sender_name: str,
|
||||
conversation_type: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
permission checks before publishing to the bus.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"DingTalk inbound: {content} from {sender_name}")
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
is_group = conversation_type == "2" and conversation_id
|
||||
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
||||
chat_id=chat_id,
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
"conversation_type": conversation_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error publishing DingTalk message: {e}")
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
|
||||
@@ -13,10 +13,11 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
@@ -32,6 +33,7 @@ class DiscordChannel(BaseChannel):
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
@@ -51,7 +53,7 @@ class DiscordChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Discord gateway error: {e}")
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
@@ -73,40 +75,118 @@ class DiscordChannel(BaseChannel):
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API."""
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
payload: dict[str, Any] = {"content": msg.content}
|
||||
|
||||
if msg.reply_to:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning(f"Discord rate limited, retrying in {retry_after}s")
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error(f"Error sending Discord message: {e}")
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
@@ -116,7 +196,7 @@ class DiscordChannel(BaseChannel):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from Discord gateway: {raw[:100]}")
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
@@ -134,6 +214,10 @@ class DiscordChannel(BaseChannel):
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
@@ -175,7 +259,7 @@ class DiscordChannel(BaseChannel):
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning(f"Discord heartbeat failed: {e}")
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
@@ -190,6 +274,7 @@ class DiscordChannel(BaseChannel):
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
@@ -197,6 +282,11 @@ class DiscordChannel(BaseChannel):
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
@@ -219,7 +309,7 @@ class DiscordChannel(BaseChannel):
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to download Discord attachment: {e}")
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
@@ -233,11 +323,32 @@ class DiscordChannel(BaseChannel):
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": payload.get("guild_id"),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
@@ -248,8 +359,11 @@ class DiscordChannel(BaseChannel):
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
except Exception:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
@@ -94,7 +94,7 @@ class EmailChannel(BaseChannel):
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Email polling error: {e}")
|
||||
logger.error("Email polling error: {}", e)
|
||||
|
||||
await asyncio.sleep(poll_seconds)
|
||||
|
||||
@@ -108,11 +108,6 @@ class EmailChannel(BaseChannel):
|
||||
logger.warning("Skip email send: consent_granted is false")
|
||||
return
|
||||
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
if not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply: auto_reply_enabled is false")
|
||||
return
|
||||
|
||||
if not self.config.smtp_host:
|
||||
logger.warning("Email channel SMTP host not configured")
|
||||
return
|
||||
@@ -122,6 +117,15 @@ class EmailChannel(BaseChannel):
|
||||
logger.warning("Email channel missing recipient address")
|
||||
return
|
||||
|
||||
# Determine if this is a reply (recipient has sent us an email before)
|
||||
is_reply = to_addr in self._last_subject_by_chat
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
|
||||
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
||||
return
|
||||
|
||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||
subject = self._reply_subject(base_subject)
|
||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||
@@ -143,7 +147,7 @@ class EmailChannel(BaseChannel):
|
||||
try:
|
||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email to {to_addr}: {e}")
|
||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||
raise
|
||||
|
||||
def _validate_config(self) -> bool:
|
||||
@@ -162,7 +166,7 @@ class EmailChannel(BaseChannel):
|
||||
missing.append("smtp_password")
|
||||
|
||||
if missing:
|
||||
logger.error(f"Email channel not configured, missing: {', '.join(missing)}")
|
||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -304,7 +308,8 @@ class EmailChannel(BaseChannel):
|
||||
self._processed_uids.add(uid)
|
||||
# mark_seen is the primary dedup; this set is a safety net
|
||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||
self._processed_uids.clear()
|
||||
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||
|
||||
if mark_seen:
|
||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -12,32 +12,28 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
|
||||
|
||||
Responsibilities:
|
||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||
- Start/stop channels
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus, session_manager: "SessionManager | None" = None):
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.session_manager = session_manager
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
self._init_channels()
|
||||
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels based on config."""
|
||||
|
||||
|
||||
# Telegram channel
|
||||
if self.config.channels.telegram.enabled:
|
||||
try:
|
||||
@@ -46,12 +42,11 @@ class ChannelManager:
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
session_manager=self.session_manager,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Telegram channel not available: {e}")
|
||||
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
# WhatsApp channel
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
@@ -61,7 +56,7 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("WhatsApp channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"WhatsApp channel not available: {e}")
|
||||
logger.warning("WhatsApp channel not available: {}", e)
|
||||
|
||||
# Discord channel
|
||||
if self.config.channels.discord.enabled:
|
||||
@@ -72,18 +67,19 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Discord channel not available: {e}")
|
||||
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
# Feishu channel
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
self.channels["feishu"] = FeishuChannel(
|
||||
self.config.channels.feishu, self.bus
|
||||
self.config.channels.feishu, self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Feishu channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Feishu channel not available: {e}")
|
||||
logger.warning("Feishu channel not available: {}", e)
|
||||
|
||||
# Mochat channel
|
||||
if self.config.channels.mochat.enabled:
|
||||
@@ -95,7 +91,7 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Mochat channel not available: {e}")
|
||||
logger.warning("Mochat channel not available: {}", e)
|
||||
|
||||
# DingTalk channel
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
@@ -106,7 +102,7 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("DingTalk channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"DingTalk channel not available: {e}")
|
||||
logger.warning("DingTalk channel not available: {}", e)
|
||||
|
||||
# Email channel
|
||||
if self.config.channels.email.enabled:
|
||||
@@ -117,7 +113,7 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("Email channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Email channel not available: {e}")
|
||||
logger.warning("Email channel not available: {}", e)
|
||||
|
||||
# Slack channel
|
||||
if self.config.channels.slack.enabled:
|
||||
@@ -128,7 +124,7 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("Slack channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Slack channel not available: {e}")
|
||||
logger.warning("Slack channel not available: {}", e)
|
||||
|
||||
# QQ channel
|
||||
if self.config.channels.qq.enabled:
|
||||
@@ -140,37 +136,59 @@ class ChannelManager:
|
||||
)
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"QQ channel not available: {e}")
|
||||
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
# Matrix channel
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
from nanobot.channels.matrix import MatrixChannel
|
||||
self.channels["matrix"] = MatrixChannel(
|
||||
self.config.channels.matrix,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start channel {name}: {e}")
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""Start all channels and the outbound dispatcher."""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
|
||||
# Start outbound dispatcher
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
|
||||
# Start channels
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info(f"Starting {name} channel...")
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
|
||||
# Wait for all to complete (they should run forever)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all channels and the dispatcher."""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
|
||||
# Stop dispatcher
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
@@ -178,44 +196,50 @@ class ChannelManager:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# Stop all channels
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info(f"Stopped {name} channel")
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping {name}: {e}")
|
||||
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to {msg.channel}: {e}")
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning(f"Unknown channel: {msg.channel}")
|
||||
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""Get a channel by name."""
|
||||
return self.channels.get(name)
|
||||
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get status of all channels."""
|
||||
return {
|
||||
@@ -225,7 +249,7 @@ class ChannelManager:
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""Get list of enabled channel names."""
|
||||
|
||||
699
nanobot/channels/matrix.py
Normal file
699
nanobot/channels/matrix.py
Normal file
@@ -0,0 +1,699 @@
|
||||
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.loader import get_data_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||
_ATTACH_MARKER = "[attachment: {}]"
|
||||
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||
_DEFAULT_ATTACH_NAME = "attachment"
|
||||
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
MATRIX_MARKDOWN = create_markdown(
|
||||
escape=True,
|
||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||
)
|
||||
|
||||
MATRIX_ALLOWED_HTML_TAGS = {
|
||||
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||
"caption", "sup", "sub", "img",
|
||||
}
|
||||
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||
"img": {"src", "alt", "title", "width", "height"},
|
||||
}
|
||||
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||
|
||||
|
||||
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||
if tag == "a" and attr == "href":
|
||||
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||
if tag == "img" and attr == "src":
|
||||
return value if value.lower().startswith("mxc://") else None
|
||||
if tag == "code" and attr == "class":
|
||||
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||
return " ".join(classes) if classes else None
|
||||
return value
|
||||
|
||||
|
||||
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||
attribute_filter=_filter_matrix_html_attribute,
|
||||
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||
strip_comments=True,
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
try:
|
||||
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||
except Exception:
|
||||
return None
|
||||
if not formatted:
|
||||
return None
|
||||
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||
inner = formatted[3:-4]
|
||||
if "<" not in inner and ">" not in inner:
|
||||
return None
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
return content
|
||||
|
||||
|
||||
class _NioLoguruHandler(logging.Handler):
|
||||
"""Route matrix-nio stdlib logs into Loguru."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame, depth = frame.f_back, depth + 1
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def _configure_nio_logging_bridge() -> None:
|
||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||
nio_logger = logging.getLogger("nio")
|
||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||
nio_logger.handlers = [_NioLoguruHandler()]
|
||||
nio_logger.propagate = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
|
||||
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
|
||||
workspace: Path | None = None):
|
||||
super().__init__(config, bus)
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = restrict_to_workspace
|
||||
self._workspace = workspace.expanduser().resolve() if workspace else None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
self._running = True
|
||||
_configure_nio_logging_bridge()
|
||||
|
||||
store_path = get_data_dir() / "matrix-store"
|
||||
store_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = AsyncClient(
|
||||
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||
store_path=store_path,
|
||||
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||
)
|
||||
self.client.user_id = self.config.user_id
|
||||
self.client.access_token = self.config.access_token
|
||||
self.client.device_id = self.config.device_id
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||
|
||||
if self.config.device_id:
|
||||
try:
|
||||
self.client.load_store()
|
||||
except Exception:
|
||||
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||
else:
|
||||
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||
self._running = False
|
||||
for room_id in list(self._typing_tasks):
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
if self.client:
|
||||
self.client.stop_sync_forever()
|
||||
if self._sync_task:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||
timeout=self.config.sync_stop_grace_seconds)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||
"""Check path is inside workspace (when restriction enabled)."""
|
||||
if not self._restrict_to_workspace or not self._workspace:
|
||||
return True
|
||||
try:
|
||||
path.resolve(strict=False).relative_to(self._workspace)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||
"""Deduplicate and resolve outbound attachment paths."""
|
||||
seen: set[str] = set()
|
||||
candidates: list[Path] = []
|
||||
for raw in media:
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
continue
|
||||
path = Path(raw.strip()).expanduser()
|
||||
try:
|
||||
key = str(path.resolve(strict=False))
|
||||
except OSError:
|
||||
key = str(path)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(path)
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_outbound_attachment_content(
|
||||
*, filename: str, mime: str, size_bytes: int,
|
||||
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||
prefix = mime.split("/")[0]
|
||||
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||
content: dict[str, Any] = {
|
||||
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||
}
|
||||
if encryption_info:
|
||||
content["file"] = {**encryption_info, "url": mxc_url}
|
||||
else:
|
||||
content["url"] = mxc_url
|
||||
return content
|
||||
|
||||
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||
if not self.client:
|
||||
return False
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
if self._server_upload_limit_checked:
|
||||
return self._server_upload_limit_bytes
|
||||
self._server_upload_limit_checked = True
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
response = await self.client.content_repository_config()
|
||||
except Exception:
|
||||
return None
|
||||
upload_size = getattr(response, "upload_size", None)
|
||||
if isinstance(upload_size, int) and upload_size > 0:
|
||||
self._server_upload_limit_bytes = upload_size
|
||||
return upload_size
|
||||
return None
|
||||
|
||||
async def _effective_media_limit_bytes(self) -> int:
|
||||
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||
if server_limit is None:
|
||||
return local_limit
|
||||
return min(local_limit, server_limit) if local_limit else 0
|
||||
|
||||
async def _upload_and_send_attachment(
|
||||
self, room_id: str, path: Path, limit_bytes: int,
|
||||
relates_to: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||
if not self.client:
|
||||
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||
|
||||
resolved = path.expanduser().resolve(strict=False)
|
||||
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||
|
||||
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||
return fail
|
||||
try:
|
||||
size_bytes = resolved.stat().st_size
|
||||
except OSError:
|
||||
return fail
|
||||
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||
return _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||
try:
|
||||
with resolved.open("rb") as f:
|
||||
upload_result = await self.client.upload(
|
||||
f, content_type=mime, filename=filename,
|
||||
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||
filesize=size_bytes,
|
||||
)
|
||||
except Exception:
|
||||
return fail
|
||||
|
||||
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||
if isinstance(upload_response, UploadError):
|
||||
return fail
|
||||
mxc_url = getattr(upload_response, "content_uri", None)
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return fail
|
||||
|
||||
content = self._build_outbound_attachment_content(
|
||||
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||
)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
try:
|
||||
await self._send_room_content(room_id, content)
|
||||
except Exception:
|
||||
return fail
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound content; clear typing for non-progress messages."""
|
||||
if not self.client:
|
||||
return
|
||||
text = msg.content or ""
|
||||
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
failures: list[str] = []
|
||||
if candidates:
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
for path in candidates:
|
||||
if fail := await self._upload_and_send_attachment(
|
||||
room_id=msg.chat_id,
|
||||
path=path,
|
||||
limit_bytes=limit_bytes,
|
||||
relates_to=relates_to,
|
||||
):
|
||||
failures.append(fail)
|
||||
if failures:
|
||||
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||
if text or not candidates:
|
||||
content = _build_matrix_text_content(text)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
await self._send_room_content(msg.chat_id, content)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _log_response_error(self, label: str, response: Any) -> None:
|
||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||
|
||||
async def _on_sync_error(self, response: SyncError) -> None:
|
||||
self._log_response_error("sync", response)
|
||||
|
||||
async def _on_join_error(self, response: JoinError) -> None:
|
||||
self._log_response_error("join", response)
|
||||
|
||||
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||
self._log_response_error("send", response)
|
||||
|
||||
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||
"""Best-effort typing indicator update."""
|
||||
if not self.client:
|
||||
return
|
||||
try:
|
||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||
if isinstance(response, RoomTypingError):
|
||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
await self._set_typing(room_id, True)
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
async def loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||
await self._set_typing(room_id, True)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||
|
||||
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||
if task := self._typing_tasks.pop(room_id, None):
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if clear_typing:
|
||||
await self._set_typing(room_id, False)
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||
if self.is_allowed(event.sender):
|
||||
await self.client.join(room.room_id)
|
||||
|
||||
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||
count = getattr(room, "member_count", None)
|
||||
return isinstance(count, int) and count <= 2
|
||||
|
||||
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||
"""Check m.mentions payload for bot mention."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return False
|
||||
mentions = (source.get("content") or {}).get("m.mentions")
|
||||
if not isinstance(mentions, dict):
|
||||
return False
|
||||
user_ids = mentions.get("user_ids")
|
||||
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||
return True
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
if self._is_direct_room(room):
|
||||
return True
|
||||
policy = self.config.group_policy
|
||||
if policy == "open":
|
||||
return True
|
||||
if policy == "allowlist":
|
||||
return room.room_id in (self.config.group_allow_from or [])
|
||||
if policy == "mention":
|
||||
return self._is_bot_mentioned(event)
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
d = get_data_dir() / "media" / "matrix"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
if not (root_id := self._event_thread_root_id(event)):
|
||||
return None
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {"rel_type": "m.thread", "event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
return (isinstance(getattr(event, "key", None), dict)
|
||||
and isinstance(getattr(event, "hashes", None), dict)
|
||||
and isinstance(getattr(event, "iv", None), str))
|
||||
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
size = info.get("size") if isinstance(info, dict) else None
|
||||
return size if isinstance(size, int) and size >= 0 else None
|
||||
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||
return m
|
||||
m = getattr(event, "mimetype", None)
|
||||
return m if isinstance(m, str) and m else None
|
||||
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
if candidate := safe_filename(Path(body).name):
|
||||
return candidate
|
||||
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||
|
||||
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||
filename: str, mime: str | None) -> Path:
|
||||
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||
suffix = Path(safe_name).suffix
|
||||
if not suffix and mime:
|
||||
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||
suffix = suffix[:16]
|
||||
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||
|
||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||
if not self.client:
|
||||
return None
|
||||
response = await self.client.download(mxc=mxc_url)
|
||||
if isinstance(response, DownloadError):
|
||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||
return None
|
||||
body = getattr(response, "body", None)
|
||||
if isinstance(body, (bytes, bytearray)):
|
||||
return bytes(body)
|
||||
if isinstance(response, MemoryDownloadResponse):
|
||||
return bytes(response.body)
|
||||
if isinstance(body, (str, Path)):
|
||||
path = Path(body)
|
||||
if path.is_file():
|
||||
try:
|
||||
return path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||
return None
|
||||
try:
|
||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||
except (EncryptionError, ValueError, TypeError):
|
||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||
return None
|
||||
|
||||
async def _fetch_media_attachment(
|
||||
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||
atype = self._event_attachment_type(event)
|
||||
mime = self._event_mime(event)
|
||||
filename = self._event_filename(event, atype)
|
||||
mxc_url = getattr(event, "url", None)
|
||||
fail = _ATTACH_FAILED.format(filename)
|
||||
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return None, fail
|
||||
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
declared = self._event_declared_size_bytes(event)
|
||||
if declared is not None and declared > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
downloaded = await self._download_media_bytes(mxc_url)
|
||||
if downloaded is None:
|
||||
return None, fail
|
||||
|
||||
encrypted = self._is_encrypted_media_event(event)
|
||||
data = downloaded
|
||||
if encrypted:
|
||||
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||
return None, fail
|
||||
|
||||
if len(data) > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
path = self._build_attachment_path(event, atype, filename, mime)
|
||||
try:
|
||||
path.write_bytes(data)
|
||||
except OSError:
|
||||
return None, fail
|
||||
|
||||
attachment = {
|
||||
"type": atype, "mime": mime, "filename": filename,
|
||||
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||
"encrypted": encrypted, "size_bytes": len(data),
|
||||
"path": str(path), "mxc_url": mxc_url,
|
||||
}
|
||||
return attachment, _ATTACH_MARKER.format(path)
|
||||
|
||||
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||
"""Build common metadata for text and media handlers."""
|
||||
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||
meta["event_id"] = eid
|
||||
if thread := self._thread_metadata(event):
|
||||
meta.update(thread)
|
||||
return meta
|
||||
|
||||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
if marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
meta = self._base_metadata(room, event)
|
||||
meta["attachments"] = []
|
||||
if attachment:
|
||||
meta["attachments"] = [attachment]
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
@@ -322,7 +322,7 @@ class MochatChannel(BaseChannel):
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Mochat message: {e}")
|
||||
logger.error("Failed to send Mochat message: {}", e)
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
@@ -380,7 +380,7 @@ class MochatChannel(BaseChannel):
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error(f"Mochat websocket connect error: {data}")
|
||||
logger.error("Mochat websocket connect error: {}", data)
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
@@ -407,7 +407,7 @@ class MochatChannel(BaseChannel):
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect Mochat websocket: {e}")
|
||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
@@ -444,7 +444,7 @@ class MochatChannel(BaseChannel):
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
|
||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
@@ -466,7 +466,7 @@ class MochatChannel(BaseChannel):
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
|
||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -488,7 +488,7 @@ class MochatChannel(BaseChannel):
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat refresh failed: {e}")
|
||||
logger.warning("Mochat refresh failed: {}", e)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@@ -502,7 +502,7 @@ class MochatChannel(BaseChannel):
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat listSessions failed: {e}")
|
||||
logger.warning("Mochat listSessions failed: {}", e)
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
@@ -536,7 +536,7 @@ class MochatChannel(BaseChannel):
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
|
||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
@@ -598,7 +598,7 @@ class MochatChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
|
||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
@@ -625,7 +625,7 @@ class MochatChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
|
||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
@@ -836,7 +836,7 @@ class MochatChannel(BaseChannel):
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read Mochat cursor file: {e}")
|
||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
@@ -852,7 +852,7 @@ class MochatChannel(BaseChannel):
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save Mochat cursor file: {e}")
|
||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -13,7 +13,7 @@ from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage, GroupMessage # 1. Import GroupMessage
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -28,27 +28,23 @@ if TYPE_CHECKING:
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
"""Create a botpy Client subclass bound to the given channel."""
|
||||
# 2. Ensure intents enable public_messages (required for group messages)
|
||||
intents = botpy.Intents(public_messages=True, direct_message=True)
|
||||
|
||||
class _Bot(botpy.Client):
|
||||
def __init__(self):
|
||||
super().__init__(intents=intents)
|
||||
# Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
|
||||
super().__init__(intents=intents, ext_handlers=False)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info(f"QQ bot ready: {self.robot.name}")
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
# C2C (Private) message
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
async def on_group_at_message_create(self, message: "GroupMessage"):
|
||||
# 3. Added: Listen for group @messages
|
||||
# Note: Official bots only receive messages @mentioning them unless privileged
|
||||
await channel._on_message(message, is_group=True)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
# Guild Direct Message
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
return _Bot
|
||||
@@ -64,10 +60,8 @@ class QQChannel(BaseChannel):
|
||||
self.config: QQConfig = config
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
self._bot_task: asyncio.Task | None = None
|
||||
# Cache to track if chat_id is a group or individual to select the correct reply API
|
||||
# Format: {chat_id: "group" | "c2c"}
|
||||
self._chat_type_cache: Dict[str, str] = {}
|
||||
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
||||
self._chat_type_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
@@ -82,9 +76,8 @@ class QQChannel(BaseChannel):
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
|
||||
self._bot_task = asyncio.create_task(self._run_bot())
|
||||
logger.info("QQ bot started (C2C & Group supported)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
@@ -92,7 +85,7 @@ class QQChannel(BaseChannel):
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning(f"QQ bot error: {e}")
|
||||
logger.warning("QQ bot error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
@@ -100,11 +93,10 @@ class QQChannel(BaseChannel):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
self._running = False
|
||||
if self._bot_task:
|
||||
self._bot_task.cancel()
|
||||
if self._client:
|
||||
try:
|
||||
await self._bot_task
|
||||
except asyncio.CancelledError:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("QQ bot stopped")
|
||||
|
||||
@@ -113,29 +105,29 @@ class QQChannel(BaseChannel):
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
# 4. Modified send logic: Check chat_id type to call the correct API
|
||||
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c") # Default to c2c
|
||||
|
||||
try:
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
self._msg_seq += 1
|
||||
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
if msg_type == "group":
|
||||
# Send group message
|
||||
await self._client.api.post_group_message(
|
||||
group_openid=msg.chat_id,
|
||||
msg_type=0,
|
||||
msg_id=msg.metadata.get("message_id"), # Reply to specific message ID (optional but recommended)
|
||||
content=msg.content
|
||||
msg_type=0,
|
||||
content=msg.content,
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
else:
|
||||
# Send C2C (private) message
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=0,
|
||||
msg_id=msg.metadata.get("message_id"),
|
||||
content=msg.content,
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending QQ message ({msg_type}): {e}")
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
@@ -149,17 +141,11 @@ class QQChannel(BaseChannel):
|
||||
if not content:
|
||||
return
|
||||
|
||||
# 5. Extract ID and cache type
|
||||
if is_group:
|
||||
# Group message: chat_id uses group_openid
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid # Sender's ID
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
|
||||
# Remove @bot text (optional, prevents Nanobot from treating the name as prompt)
|
||||
# content = content.replace("@BotName", "").strip()
|
||||
else:
|
||||
# Private message: chat_id uses user_openid
|
||||
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
@@ -170,5 +156,5 @@ class QQChannel(BaseChannel):
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling QQ message: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
|
||||
@@ -5,10 +5,11 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@@ -34,7 +35,7 @@ class SlackChannel(BaseChannel):
|
||||
logger.error("Slack bot/app token not configured")
|
||||
return
|
||||
if self.config.mode != "socket":
|
||||
logger.error(f"Unsupported Slack mode: {self.config.mode}")
|
||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||
return
|
||||
|
||||
self._running = True
|
||||
@@ -51,9 +52,9 @@ class SlackChannel(BaseChannel):
|
||||
try:
|
||||
auth = await self._web_client.auth_test()
|
||||
self._bot_user_id = auth.get("user_id")
|
||||
logger.info(f"Slack bot connected as {self._bot_user_id}")
|
||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Slack auth_test failed: {e}")
|
||||
logger.warning("Slack auth_test failed: {}", e)
|
||||
|
||||
logger.info("Starting Slack Socket Mode client...")
|
||||
await self._socket_client.connect()
|
||||
@@ -68,7 +69,7 @@ class SlackChannel(BaseChannel):
|
||||
try:
|
||||
await self._socket_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Slack socket close failed: {e}")
|
||||
logger.warning("Slack socket close failed: {}", e)
|
||||
self._socket_client = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
@@ -82,13 +83,26 @@ class SlackChannel(BaseChannel):
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||
use_thread = thread_ts and channel_type != "im"
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=msg.content or "",
|
||||
thread_ts=thread_ts if use_thread else None,
|
||||
)
|
||||
thread_ts_param = thread_ts if use_thread else None
|
||||
|
||||
if msg.content:
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content),
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending Slack message: {e}")
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
@@ -150,30 +164,39 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts") or event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
await self._web_client.reactions_add(
|
||||
channel=chat_id,
|
||||
name="eyes",
|
||||
name=self.config.react_emoji,
|
||||
timestamp=event.get("ts"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Slack reactions_add failed: {e}")
|
||||
logger.debug("Slack reactions_add failed: {}", e)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
},
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
if channel_type == "im":
|
||||
@@ -203,3 +226,55 @@ class SlackChannel(BaseChannel):
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||
|
||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||
|
||||
@classmethod
|
||||
def _to_mrkdwn(cls, text: str) -> str:
|
||||
"""Convert Markdown to Slack mrkdwn, including tables."""
|
||||
if not text:
|
||||
return ""
|
||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||
|
||||
@classmethod
|
||||
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||
code_blocks: list[str] = []
|
||||
|
||||
def _save_code(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(0))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||
|
||||
for i, block in enumerate(code_blocks):
|
||||
text = text.replace(f"\x00CB{i}\x00", block)
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _convert_table(match: re.Match) -> str:
|
||||
"""Convert a Markdown table to a Slack-readable list."""
|
||||
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
|
||||
if len(lines) < 2:
|
||||
return match.group(0)
|
||||
headers = [h.strip() for h in lines[0].strip("|").split("|")]
|
||||
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
|
||||
rows: list[str] = []
|
||||
for line in lines[start:]:
|
||||
cells = [c.strip() for c in line.strip("|").split("|")]
|
||||
cells = (cells + [""] * len(headers))[: len(headers)]
|
||||
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
@@ -4,20 +4,62 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, Update
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from telegram import BotCommand, ReplyParameters, Update
|
||||
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||
|
||||
|
||||
def _strip_md(s: str) -> str:
|
||||
"""Strip markdown inline formatting from text."""
|
||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def _render_table_box(table_lines: list[str]) -> str:
|
||||
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
|
||||
|
||||
def dw(s: str) -> int:
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||
|
||||
rows: list[list[str]] = []
|
||||
has_sep = False
|
||||
for line in table_lines:
|
||||
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
|
||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||
has_sep = True
|
||||
continue
|
||||
rows.append(cells)
|
||||
if not rows or not has_sep:
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
ncols = max(len(r) for r in rows)
|
||||
for r in rows:
|
||||
r.extend([''] * (ncols - len(r)))
|
||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||
|
||||
def dr(cells: list[str]) -> str:
|
||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||
|
||||
out = [dr(rows[0])]
|
||||
out.append(' '.join('─' * w for w in widths))
|
||||
for row in rows[1:]:
|
||||
out.append(dr(row))
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
@@ -26,275 +68,433 @@ def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
|
||||
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
|
||||
lines = text.split('\n')
|
||||
rebuilt: list[str] = []
|
||||
li = 0
|
||||
while li < len(lines):
|
||||
if re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl: list[str] = []
|
||||
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl.append(lines[li])
|
||||
li += 1
|
||||
box = _render_table_box(tbl)
|
||||
if box != '\n'.join(tbl):
|
||||
code_blocks.append(box)
|
||||
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
|
||||
else:
|
||||
rebuilt.extend(tbl)
|
||||
else:
|
||||
rebuilt.append(lines[li])
|
||||
li += 1
|
||||
text = '\n'.join(rebuilt)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
|
||||
name = "telegram"
|
||||
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
session_manager: SessionManager | None = None,
|
||||
):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self.session_manager = session_manager
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
|
||||
self._media_group_buffers: dict[str, dict] = {}
|
||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||
self._message_threads: dict[tuple[str, int], int] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
req = HTTPXRequest(
|
||||
connection_pool_size=16,
|
||||
pool_timeout=5.0,
|
||||
connect_timeout=30.0,
|
||||
read_timeout=30.0,
|
||||
proxy=self.config.proxy if self.config.proxy else None,
|
||||
)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
if self.config.proxy:
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("reset", self._on_reset))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
logger.info(f"Telegram bot @{bot_info.username} connected")
|
||||
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register bot commands: {e}")
|
||||
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
|
||||
|
||||
for task in self._media_group_tasks.values():
|
||||
task.cancel()
|
||||
self._media_group_tasks.clear()
|
||||
self._media_group_buffers.clear()
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
|
||||
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
|
||||
return "photo"
|
||||
if ext == "ogg":
|
||||
return "voice"
|
||||
if ext in ("mp3", "m4a", "wav", "aac"):
|
||||
return "audio"
|
||||
return "document"
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Telegram."""
|
||||
if not self._app:
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
# Stop typing indicator for this chat
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
|
||||
# Only stop typing indicator for final responses
|
||||
if not msg.metadata.get("_progress", False):
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
try:
|
||||
# chat_id should be the Telegram chat ID (integer)
|
||||
chat_id = int(msg.chat_id)
|
||||
# Convert markdown to Telegram HTML
|
||||
html_content = _markdown_to_telegram_html(msg.content)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=html_content,
|
||||
parse_mode="HTML"
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid chat_id: {msg.chat_id}")
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
message_thread_id = msg.metadata.get("message_thread_id")
|
||||
if message_thread_id is None and reply_to_message_id is not None:
|
||||
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||
thread_kwargs = {}
|
||||
if message_thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = message_thread_id
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
allow_sending_without_reply=True
|
||||
)
|
||||
|
||||
# Send media files
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
media_type = self._get_media_type(media_path)
|
||||
sender = {
|
||||
"photo": self._app.bot.send_photo,
|
||||
"voice": self._app.bot.send_voice,
|
||||
"audio": self._app.bot.send_audio,
|
||||
}.get(media_type, self._app.bot.send_document)
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
logger.error("Failed to send media {}: {}", media_path, e)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
is_progress = msg.metadata.get("_progress", False)
|
||||
|
||||
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||
# Final response: simulate streaming via draft, then persist
|
||||
if not is_progress:
|
||||
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||
else:
|
||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||
|
||||
async def _send_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Send a plain text message with HTML fallback."""
|
||||
try:
|
||||
html = _markdown_to_telegram_html(text)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback to plain text if HTML parsing fails
|
||||
logger.warning(f"HTML parse failed, falling back to plain text: {e}")
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=int(msg.chat_id),
|
||||
text=msg.content
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error(f"Error sending Telegram message: {e2}")
|
||||
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _send_with_streaming(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||
draft_id = int(time.time() * 1000) % (2**31)
|
||||
try:
|
||||
step = max(len(text) // 8, 40)
|
||||
for i in range(step, len(text), step):
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text[:i],
|
||||
)
|
||||
await asyncio.sleep(0.04)
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text,
|
||||
)
|
||||
await asyncio.sleep(0.15)
|
||||
except Exception:
|
||||
pass
|
||||
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
|
||||
user = update.effective_user
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||
"Send me a message and I'll respond!\n"
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /reset command — clear conversation history."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
chat_id = str(update.message.chat_id)
|
||||
session_key = f"{self.name}:{chat_id}"
|
||||
|
||||
if self.session_manager is None:
|
||||
logger.warning("/reset called but session_manager is not available")
|
||||
await update.message.reply_text("⚠️ Session management is not available.")
|
||||
return
|
||||
|
||||
session = self.session_manager.get_or_create(session_key)
|
||||
msg_count = len(session.messages)
|
||||
session.clear()
|
||||
self.session_manager.save(session)
|
||||
|
||||
logger.info(f"Session reset for {session_key} (cleared {msg_count} messages)")
|
||||
await update.message.reply_text("🔄 Conversation history cleared. Let's start fresh!")
|
||||
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command — show available commands."""
|
||||
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
help_text = (
|
||||
"🐈 <b>nanobot commands</b>\n\n"
|
||||
"/start — Start the bot\n"
|
||||
"/reset — Reset conversation history\n"
|
||||
"/help — Show this help message\n\n"
|
||||
"Just send me a text message to chat!"
|
||||
await update.message.reply_text(
|
||||
"🐈 nanobot commands:\n"
|
||||
"/new — Start a new conversation\n"
|
||||
"/stop — Stop the current task\n"
|
||||
"/help — Show available commands"
|
||||
)
|
||||
await update.message.reply_text(help_text, parse_mode="HTML")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _sender_id(user) -> str:
|
||||
"""Build sender_id with username for allowlist matching."""
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
@staticmethod
|
||||
def _derive_topic_session_key(message) -> str | None:
|
||||
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message.chat.type == "private" or message_thread_id is None:
|
||||
return None
|
||||
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||
|
||||
@staticmethod
|
||||
def _build_message_metadata(message, user) -> dict:
|
||||
"""Build common Telegram inbound metadata payload."""
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private",
|
||||
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||
}
|
||||
|
||||
def _remember_thread_context(self, message) -> None:
|
||||
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message_thread_id is None:
|
||||
return
|
||||
key = (str(message.chat_id), message.message_id)
|
||||
self._message_threads[key] = message_thread_id
|
||||
if len(self._message_threads) > 1000:
|
||||
self._message_threads.pop(next(iter(self._message_threads)))
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
self._remember_thread_context(message)
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(user),
|
||||
chat_id=str(message.chat_id),
|
||||
content=message.text,
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
|
||||
# Use stable numeric ID, but keep username for allowlist compatibility
|
||||
sender_id = str(user.id)
|
||||
if user.username:
|
||||
sender_id = f"{sender_id}|{user.username}"
|
||||
|
||||
sender_id = self._sender_id(user)
|
||||
self._remember_thread_context(message)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
@@ -307,77 +507,112 @@ class TelegramChannel(BaseChannel):
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||
|
||||
ext = self._get_extension(
|
||||
media_type,
|
||||
getattr(media_file, 'mime_type', None),
|
||||
getattr(media_file, 'file_name', None),
|
||||
)
|
||||
# Save to workspace/media/
|
||||
from pathlib import Path
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
logger.info(f"Transcribed {media_type}: {transcription[:50]}...")
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug(f"Downloaded {media_type} to {file_path}")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download media: {e}")
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug(f"Telegram message from {sender_id}: {content[:50]}...")
|
||||
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
|
||||
metadata = self._build_message_metadata(message, user)
|
||||
session_key = self._derive_topic_session_key(message)
|
||||
|
||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||
if media_group_id := getattr(message, "media_group_id", None):
|
||||
key = f"{str_chat_id}:{media_group_id}"
|
||||
if key not in self._media_group_buffers:
|
||||
self._media_group_buffers[key] = {
|
||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||
"contents": [], "media": [],
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
}
|
||||
self._start_typing(str_chat_id)
|
||||
buf = self._media_group_buffers[key]
|
||||
if content and content != "[empty message]":
|
||||
buf["contents"].append(content)
|
||||
buf["media"].extend(media_paths)
|
||||
if key not in self._media_group_tasks:
|
||||
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||
return
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private"
|
||||
}
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
|
||||
async def _flush_media_group(self, key: str) -> None:
|
||||
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||
try:
|
||||
await asyncio.sleep(0.6)
|
||||
if not (buf := self._media_group_buffers.pop(key, None)):
|
||||
return
|
||||
content = "\n".join(buf["contents"]) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||
metadata=buf["metadata"],
|
||||
session_key=buf.get("session_key"),
|
||||
)
|
||||
finally:
|
||||
self._media_group_tasks.pop(key, None)
|
||||
|
||||
def _start_typing(self, chat_id: str) -> None:
|
||||
"""Start sending 'typing...' indicator for a chat."""
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
@@ -387,14 +622,19 @@ class TelegramChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
|
||||
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error(f"Telegram error: {context.error}")
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||
"""Get file extension based on media type."""
|
||||
def _get_extension(
|
||||
self,
|
||||
media_type: str,
|
||||
mime_type: str | None,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get file extension based on media type or original filename."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
@@ -402,6 +642,14 @@ class TelegramChannel(BaseChannel):
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
return type_map.get(media_type, "")
|
||||
if ext := type_map.get(media_type, ""):
|
||||
return ext
|
||||
|
||||
if filename:
|
||||
from pathlib import Path
|
||||
|
||||
return "".join(Path(filename).suffixes)
|
||||
|
||||
return ""
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -15,131 +16,155 @@ from nanobot.config.schema import WhatsAppConfig
|
||||
class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
WhatsApp channel that connects to a Node.js bridge.
|
||||
|
||||
|
||||
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
|
||||
Communication between Python and Node.js is via WebSocket.
|
||||
"""
|
||||
|
||||
|
||||
name = "whatsapp"
|
||||
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig = config
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
logger.info(f"Connecting to WhatsApp bridge at {bridge_url}...")
|
||||
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
self._running = True
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling bridge message: {e}")
|
||||
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning(f"WhatsApp bridge connection error: {e}")
|
||||
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
"to": msg.chat_id,
|
||||
"text": msg.content
|
||||
}
|
||||
await self._ws.send(json.dumps(payload))
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending WhatsApp message: {e}")
|
||||
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON from bridge: {raw[:100]}")
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
pn = data.get("pn", "")
|
||||
# New LID sytle typically:
|
||||
# New LID sytle typically:
|
||||
sender = data.get("sender", "")
|
||||
content = data.get("content", "")
|
||||
|
||||
message_id = data.get("id", "")
|
||||
|
||||
if message_id:
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Extract just the phone number or lid as chat_id
|
||||
user_id = pn if pn else sender
|
||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||
logger.info(f"Sender {sender}")
|
||||
|
||||
logger.info("Sender {}", sender)
|
||||
|
||||
# Handle voice transcription if it's a voice message
|
||||
if content == "[Voice Message]":
|
||||
logger.info(f"Voice message received from {sender_id}, but direct download from bridge is not yet supported.")
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
|
||||
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||
media_paths = data.get("media") or []
|
||||
|
||||
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
|
||||
if media_paths:
|
||||
for p in media_paths:
|
||||
mime, _ = mimetypes.guess_type(p)
|
||||
media_type = "image" if mime and mime.startswith("image/") else "file"
|
||||
media_tag = f"[{media_type}: {p}]"
|
||||
content = f"{content}\n{media_tag}" if content else media_tag
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": data.get("id"),
|
||||
"message_id": message_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info(f"WhatsApp status: {status}")
|
||||
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error(f"WhatsApp bridge error: {data.get('error')}")
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
|
||||
Reference in New Issue
Block a user