refactor(matrix): use base media event filter for callbacks

- Replaces the explicit media event tuple with MATRIX_MEDIA_EVENT_FILTER
  based on
  media base classes: (RoomMessageMedia, RoomEncryptedMedia).
- Keeps MatrixMediaEvent as the static typing alias for media-specific
  handlers.
- Removes MatrixInboundEvent and uses RoomMessage in mention-related
  logic.
- Adds regression tests for:
  - callback registration using MATRIX_MEDIA_EVENT_FILTER
  - ensuring RoomMessageText is not matched by the media filter.
This commit is contained in:
Alexander Minges
2026-02-12 22:58:53 +01:00
committed by Alexander Minges
parent 1103f000fc
commit 10de3bf329
2 changed files with 79 additions and 36 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import logging import logging
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, TypeAlias
import nh3 import nh3
from loguru import logger from loguru import logger
@@ -15,15 +15,10 @@ from nio import (
JoinError, JoinError,
MatrixRoom, MatrixRoom,
MemoryDownloadResponse, MemoryDownloadResponse,
RoomEncryptedAudio, RoomEncryptedMedia,
RoomEncryptedFile, RoomMessage,
RoomEncryptedImage, RoomMessageMedia,
RoomEncryptedVideo,
RoomMessageAudio,
RoomMessageFile,
RoomMessageImage,
RoomMessageText, RoomMessageText,
RoomMessageVideo,
RoomSendError, RoomSendError,
RoomTypingError, RoomTypingError,
SyncError, SyncError,
@@ -51,16 +46,10 @@ MATRIX_ATTACHMENT_TOO_LARGE_TEMPLATE = "[attachment: {} - too large]"
MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]" MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]"
MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment" MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment"
MATRIX_MEDIA_EVENT_TYPES = ( # Runtime callback filter for nio event dispatch (checked via isinstance).
RoomMessageImage, MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
RoomMessageFile, # Static typing alias for media-specific handlers/helpers.
RoomMessageAudio, MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
RoomMessageVideo,
RoomEncryptedImage,
RoomEncryptedFile,
RoomEncryptedAudio,
RoomEncryptedVideo,
)
# Markdown renderer policy: # Markdown renderer policy:
# https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes # https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes
@@ -345,7 +334,7 @@ class MatrixChannel(BaseChannel):
def _register_event_callbacks(self) -> None: def _register_event_callbacks(self) -> None:
"""Register Matrix event callbacks used by this channel.""" """Register Matrix event callbacks used by this channel."""
self.client.add_event_callback(self._on_message, RoomMessageText) self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_TYPES) self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
self.client.add_event_callback(self._on_room_invite, InviteEvent) self.client.add_event_callback(self._on_room_invite, InviteEvent)
def _register_response_callbacks(self) -> None: def _register_response_callbacks(self) -> None:
@@ -460,7 +449,7 @@ class MatrixChannel(BaseChannel):
member_count = getattr(room, "member_count", None) member_count = getattr(room, "member_count", None)
return isinstance(member_count, int) and member_count <= 2 return isinstance(member_count, int) and member_count <= 2
def _is_bot_mentioned_from_mx_mentions(self, event: Any) -> bool: def _is_bot_mentioned_from_mx_mentions(self, event: RoomMessage) -> bool:
"""Resolve mentions strictly from Matrix-native m.mentions payload.""" """Resolve mentions strictly from Matrix-native m.mentions payload."""
source = getattr(event, "source", None) source = getattr(event, "source", None)
if not isinstance(source, dict): if not isinstance(source, dict):
@@ -480,7 +469,7 @@ class MatrixChannel(BaseChannel):
return bool(self.config.allow_room_mentions and mentions.get("room") is True) return bool(self.config.allow_room_mentions and mentions.get("room") is True)
def _should_process_message(self, room: MatrixRoom, event: Any) -> bool: def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
"""Apply sender and room policy checks before processing Matrix messages.""" """Apply sender and room policy checks before processing Matrix messages."""
if not self.is_allowed(event.sender): if not self.is_allowed(event.sender):
return False return False
@@ -505,7 +494,7 @@ class MatrixChannel(BaseChannel):
return media_dir return media_dir
@staticmethod @staticmethod
def _event_source_content(event: Any) -> dict[str, Any]: def _event_source_content(event: RoomMessage) -> dict[str, Any]:
"""Extract Matrix event content payload when available.""" """Extract Matrix event content payload when available."""
source = getattr(event, "source", None) source = getattr(event, "source", None)
if not isinstance(source, dict): if not isinstance(source, dict):
@@ -513,7 +502,47 @@ class MatrixChannel(BaseChannel):
content = source.get("content") content = source.get("content")
return content if isinstance(content, dict) else {} return content if isinstance(content, dict) else {}
def _event_attachment_type(self, event: Any) -> str: def _event_thread_root_id(self, event: RoomMessage) -> str | None:
"""Return thread root event_id if this message is inside a thread."""
content = self._event_source_content(event)
relates_to = content.get("m.relates_to")
if not isinstance(relates_to, dict):
return None
if relates_to.get("rel_type") != "m.thread":
return None
root_id = relates_to.get("event_id")
return root_id if isinstance(root_id, str) and root_id else None
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
"""Build metadata used to reply within a thread."""
root_id = self._event_thread_root_id(event)
if not root_id:
return None
reply_to = getattr(event, "event_id", None)
meta: dict[str, str] = {"thread_root_event_id": root_id}
if isinstance(reply_to, str) and reply_to:
meta["thread_reply_to_event_id"] = reply_to
return meta
@staticmethod
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
"""Build m.relates_to payload for Matrix thread replies."""
if not metadata:
return None
root_id = metadata.get("thread_root_event_id")
if not isinstance(root_id, str) or not root_id:
return None
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
if not isinstance(reply_to, str) or not reply_to:
return None
return {
"rel_type": "m.thread",
"event_id": root_id,
"m.in_reply_to": {"event_id": reply_to},
"is_falling_back": True,
}
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
"""Map Matrix event payload/type to a stable attachment kind.""" """Map Matrix event payload/type to a stable attachment kind."""
msgtype = self._event_source_content(event).get("msgtype") msgtype = self._event_source_content(event).get("msgtype")
if msgtype == "m.image": if msgtype == "m.image":
@@ -535,7 +564,7 @@ class MatrixChannel(BaseChannel):
return "file" return "file"
@staticmethod @staticmethod
def _is_encrypted_media_event(event: Any) -> bool: def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
"""Return True for encrypted Matrix media events.""" """Return True for encrypted Matrix media events."""
return ( return (
isinstance(getattr(event, "key", None), dict) isinstance(getattr(event, "key", None), dict)
@@ -543,7 +572,7 @@ class MatrixChannel(BaseChannel):
and isinstance(getattr(event, "iv", None), str) and isinstance(getattr(event, "iv", None), str)
) )
def _event_declared_size_bytes(self, event: Any) -> int | None: def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
"""Return declared media size from Matrix event info, if present.""" """Return declared media size from Matrix event info, if present."""
info = self._event_source_content(event).get("info") info = self._event_source_content(event).get("info")
if not isinstance(info, dict): if not isinstance(info, dict):
@@ -553,7 +582,7 @@ class MatrixChannel(BaseChannel):
return size return size
return None return None
def _event_mime(self, event: Any) -> str | None: def _event_mime(self, event: MatrixMediaEvent) -> str | None:
"""Best-effort MIME extraction from Matrix media event.""" """Best-effort MIME extraction from Matrix media event."""
info = self._event_source_content(event).get("info") info = self._event_source_content(event).get("info")
if isinstance(info, dict): if isinstance(info, dict):
@@ -566,7 +595,7 @@ class MatrixChannel(BaseChannel):
return mime return mime
return None return None
def _event_filename(self, event: Any, attachment_type: str) -> str: def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
"""Build a safe filename for a Matrix attachment.""" """Build a safe filename for a Matrix attachment."""
body = getattr(event, "body", None) body = getattr(event, "body", None)
if isinstance(body, str) and body.strip(): if isinstance(body, str) and body.strip():
@@ -577,7 +606,7 @@ class MatrixChannel(BaseChannel):
def _build_attachment_path( def _build_attachment_path(
self, self,
event: Any, event: MatrixMediaEvent,
attachment_type: str, attachment_type: str,
filename: str, filename: str,
mime: str | None, mime: str | None,
@@ -637,7 +666,7 @@ class MatrixChannel(BaseChannel):
) )
return None return None
def _decrypt_media_bytes(self, event: Any, ciphertext: bytes) -> bytes | None: def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
"""Decrypt encrypted Matrix attachment bytes.""" """Decrypt encrypted Matrix attachment bytes."""
key_obj = getattr(event, "key", None) key_obj = getattr(event, "key", None)
hashes = getattr(event, "hashes", None) hashes = getattr(event, "hashes", None)
@@ -666,7 +695,7 @@ class MatrixChannel(BaseChannel):
async def _fetch_media_attachment( async def _fetch_media_attachment(
self, self,
room: MatrixRoom, room: MatrixRoom,
event: Any, event: MatrixMediaEvent,
) -> tuple[dict[str, Any] | None, str]: ) -> tuple[dict[str, Any] | None, str]:
"""Download and prepare a Matrix attachment for inbound processing.""" """Download and prepare a Matrix attachment for inbound processing."""
attachment_type = self._event_attachment_type(event) attachment_type = self._event_attachment_type(event)
@@ -683,10 +712,7 @@ class MatrixChannel(BaseChannel):
return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename) return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename)
declared_size = self._event_declared_size_bytes(event) declared_size = self._event_declared_size_bytes(event)
if ( if declared_size is not None and declared_size > self.config.max_inbound_media_bytes:
declared_size is not None
and declared_size > self.config.max_inbound_media_bytes
):
logger.warning( logger.warning(
"Matrix attachment skipped in room {}: declared size {} exceeds limit {}", "Matrix attachment skipped in room {}: declared size {} exceeds limit {}",
room.room_id, room.room_id,
@@ -765,7 +791,7 @@ class MatrixChannel(BaseChannel):
await self._stop_typing_keepalive(room.room_id, clear_typing=True) await self._stop_typing_keepalive(room.room_id, clear_typing=True)
raise raise
async def _on_media_message(self, room: MatrixRoom, event: Any) -> None: async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
"""Handle inbound Matrix media events and forward local attachment paths.""" """Handle inbound Matrix media events and forward local attachment paths."""
if event.sender == self.config.user_id: if event.sender == self.config.user_id:
return return

View File

@@ -159,6 +159,23 @@ async def test_start_skips_load_store_when_device_id_missing(
await channel.stop() await channel.stop()
@pytest.mark.asyncio
async def test_register_event_callbacks_uses_media_base_filter() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._register_event_callbacks()
assert len(client.callbacks) == 3
assert client.callbacks[1][0] == channel._on_media_message
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
def test_media_event_filter_does_not_match_text_events() -> None:
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_disables_e2ee_when_configured( async def test_start_disables_e2ee_when_configured(
monkeypatch, tmp_path monkeypatch, tmp_path