diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index fcff534..51df4e8 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -2,7 +2,7 @@ import asyncio import logging import mimetypes from pathlib import Path -from typing import Any +from typing import Any, TypeAlias import nh3 from loguru import logger @@ -15,15 +15,10 @@ from nio import ( JoinError, MatrixRoom, MemoryDownloadResponse, - RoomEncryptedAudio, - RoomEncryptedFile, - RoomEncryptedImage, - RoomEncryptedVideo, - RoomMessageAudio, - RoomMessageFile, - RoomMessageImage, + RoomEncryptedMedia, + RoomMessage, + RoomMessageMedia, RoomMessageText, - RoomMessageVideo, RoomSendError, RoomTypingError, SyncError, @@ -51,16 +46,10 @@ MATRIX_ATTACHMENT_TOO_LARGE_TEMPLATE = "[attachment: {} - too large]" MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]" MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment" -MATRIX_MEDIA_EVENT_TYPES = ( - RoomMessageImage, - RoomMessageFile, - RoomMessageAudio, - RoomMessageVideo, - RoomEncryptedImage, - RoomEncryptedFile, - RoomEncryptedAudio, - RoomEncryptedVideo, -) +# Runtime callback filter for nio event dispatch (checked via isinstance). +MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia) +# Static typing alias for media-specific handlers/helpers. +MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia # Markdown renderer policy: # https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes @@ -345,7 +334,7 @@ class MatrixChannel(BaseChannel): def _register_event_callbacks(self) -> None: """Register Matrix event callbacks used by this channel.""" self.client.add_event_callback(self._on_message, RoomMessageText) - self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_TYPES) + self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) self.client.add_event_callback(self._on_room_invite, InviteEvent) def _register_response_callbacks(self) -> None: @@ -460,7 +449,7 @@ class MatrixChannel(BaseChannel): member_count = getattr(room, "member_count", None) return isinstance(member_count, int) and member_count <= 2 - def _is_bot_mentioned_from_mx_mentions(self, event: Any) -> bool: + def _is_bot_mentioned_from_mx_mentions(self, event: RoomMessage) -> bool: """Resolve mentions strictly from Matrix-native m.mentions payload.""" source = getattr(event, "source", None) if not isinstance(source, dict): @@ -480,7 +469,7 @@ class MatrixChannel(BaseChannel): return bool(self.config.allow_room_mentions and mentions.get("room") is True) - def _should_process_message(self, room: MatrixRoom, event: Any) -> bool: + def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool: """Apply sender and room policy checks before processing Matrix messages.""" if not self.is_allowed(event.sender): return False @@ -505,7 +494,7 @@ class MatrixChannel(BaseChannel): return media_dir @staticmethod - def _event_source_content(event: Any) -> dict[str, Any]: + def _event_source_content(event: RoomMessage) -> dict[str, Any]: """Extract Matrix event content payload when available.""" source = getattr(event, "source", None) if not isinstance(source, dict): @@ -513,7 +502,47 @@ class MatrixChannel(BaseChannel): content = source.get("content") return content if isinstance(content, dict) else {} - def _event_attachment_type(self, event: Any) -> str: + def _event_thread_root_id(self, event: RoomMessage) -> str | None: + """Return thread root event_id if this message is inside a thread.""" + content = self._event_source_content(event) + relates_to = content.get("m.relates_to") + if not isinstance(relates_to, dict): + return None + if relates_to.get("rel_type") != "m.thread": + return None + root_id = relates_to.get("event_id") + return root_id if isinstance(root_id, str) and root_id else None + + def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None: + """Build metadata used to reply within a thread.""" + root_id = self._event_thread_root_id(event) + if not root_id: + return None + reply_to = getattr(event, "event_id", None) + meta: dict[str, str] = {"thread_root_event_id": root_id} + if isinstance(reply_to, str) and reply_to: + meta["thread_reply_to_event_id"] = reply_to + return meta + + @staticmethod + def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None: + """Build m.relates_to payload for Matrix thread replies.""" + if not metadata: + return None + root_id = metadata.get("thread_root_event_id") + if not isinstance(root_id, str) or not root_id: + return None + reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id") + if not isinstance(reply_to, str) or not reply_to: + return None + return { + "rel_type": "m.thread", + "event_id": root_id, + "m.in_reply_to": {"event_id": reply_to}, + "is_falling_back": True, + } + + def _event_attachment_type(self, event: MatrixMediaEvent) -> str: """Map Matrix event payload/type to a stable attachment kind.""" msgtype = self._event_source_content(event).get("msgtype") if msgtype == "m.image": @@ -535,7 +564,7 @@ class MatrixChannel(BaseChannel): return "file" @staticmethod - def _is_encrypted_media_event(event: Any) -> bool: + def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool: """Return True for encrypted Matrix media events.""" return ( isinstance(getattr(event, "key", None), dict) @@ -543,7 +572,7 @@ class MatrixChannel(BaseChannel): and isinstance(getattr(event, "iv", None), str) ) - def _event_declared_size_bytes(self, event: Any) -> int | None: + def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None: """Return declared media size from Matrix event info, if present.""" info = self._event_source_content(event).get("info") if not isinstance(info, dict): @@ -553,7 +582,7 @@ class MatrixChannel(BaseChannel): return size return None - def _event_mime(self, event: Any) -> str | None: + def _event_mime(self, event: MatrixMediaEvent) -> str | None: """Best-effort MIME extraction from Matrix media event.""" info = self._event_source_content(event).get("info") if isinstance(info, dict): @@ -566,7 +595,7 @@ class MatrixChannel(BaseChannel): return mime return None - def _event_filename(self, event: Any, attachment_type: str) -> str: + def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str: """Build a safe filename for a Matrix attachment.""" body = getattr(event, "body", None) if isinstance(body, str) and body.strip(): @@ -577,7 +606,7 @@ class MatrixChannel(BaseChannel): def _build_attachment_path( self, - event: Any, + event: MatrixMediaEvent, attachment_type: str, filename: str, mime: str | None, @@ -637,7 +666,7 @@ class MatrixChannel(BaseChannel): ) return None - def _decrypt_media_bytes(self, event: Any, ciphertext: bytes) -> bytes | None: + def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None: """Decrypt encrypted Matrix attachment bytes.""" key_obj = getattr(event, "key", None) hashes = getattr(event, "hashes", None) @@ -666,7 +695,7 @@ class MatrixChannel(BaseChannel): async def _fetch_media_attachment( self, room: MatrixRoom, - event: Any, + event: MatrixMediaEvent, ) -> tuple[dict[str, Any] | None, str]: """Download and prepare a Matrix attachment for inbound processing.""" attachment_type = self._event_attachment_type(event) @@ -683,10 +712,7 @@ class MatrixChannel(BaseChannel): return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename) declared_size = self._event_declared_size_bytes(event) - if ( - declared_size is not None - and declared_size > self.config.max_inbound_media_bytes - ): + if declared_size is not None and declared_size > self.config.max_inbound_media_bytes: logger.warning( "Matrix attachment skipped in room {}: declared size {} exceeds limit {}", room.room_id, @@ -765,7 +791,7 @@ class MatrixChannel(BaseChannel): await self._stop_typing_keepalive(room.room_id, clear_typing=True) raise - async def _on_media_message(self, room: MatrixRoom, event: Any) -> None: + async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None: """Handle inbound Matrix media events and forward local attachment paths.""" if event.sender == self.config.user_id: return diff --git a/tests/test_matrix_channel.py b/tests/test_matrix_channel.py index 6ea955d..164ec2e 100644 --- a/tests/test_matrix_channel.py +++ b/tests/test_matrix_channel.py @@ -159,6 +159,23 @@ async def test_start_skips_load_store_when_device_id_missing( await channel.stop() +@pytest.mark.asyncio +async def test_register_event_callbacks_uses_media_base_filter() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._register_event_callbacks() + + assert len(client.callbacks) == 3 + assert client.callbacks[1][0] == channel._on_media_message + assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER + + +def test_media_event_filter_does_not_match_text_events() -> None: + assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER) + + @pytest.mark.asyncio async def test_start_disables_e2ee_when_configured( monkeypatch, tmp_path