merge origin/main into pr-1567
This commit is contained in:
@@ -33,6 +33,7 @@ class DiscordChannel(BaseChannel):
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
@@ -149,6 +150,10 @@ class DiscordChannel(BaseChannel):
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
@@ -205,6 +210,7 @@ class DiscordChannel(BaseChannel):
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
@@ -212,6 +218,11 @@ class DiscordChannel(BaseChannel):
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
@@ -248,11 +259,32 @@ class DiscordChannel(BaseChannel):
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": payload.get("guild_id"),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
Reference in New Issue
Block a user