refactor(matrix): use base media event filter for callbacks

- Replaces the explicit media event tuple with MATRIX_MEDIA_EVENT_FILTER
  based on
  media base classes: (RoomMessageMedia, RoomEncryptedMedia).
- Keeps MatrixMediaEvent as the static typing alias for media-specific
  handlers.
- Removes MatrixInboundEvent and uses RoomMessage in mention-related
  logic.
- Adds regression tests for:
  - callback registration using MATRIX_MEDIA_EVENT_FILTER
  - ensuring RoomMessageText is not matched by the media filter.
This commit is contained in:
Alexander Minges
2026-02-12 22:58:53 +01:00
committed by Alexander Minges
parent 1103f000fc
commit 10de3bf329
2 changed files with 79 additions and 36 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import mimetypes
from pathlib import Path
from typing import Any
from typing import Any, TypeAlias
import nh3
from loguru import logger
@@ -15,15 +15,10 @@ from nio import (
JoinError,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedAudio,
RoomEncryptedFile,
RoomEncryptedImage,
RoomEncryptedVideo,
RoomMessageAudio,
RoomMessageFile,
RoomMessageImage,
RoomEncryptedMedia,
RoomMessage,
RoomMessageMedia,
RoomMessageText,
RoomMessageVideo,
RoomSendError,
RoomTypingError,
SyncError,
@@ -51,16 +46,10 @@ MATRIX_ATTACHMENT_TOO_LARGE_TEMPLATE = "[attachment: {} - too large]"
MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]"
MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment"
MATRIX_MEDIA_EVENT_TYPES = (
RoomMessageImage,
RoomMessageFile,
RoomMessageAudio,
RoomMessageVideo,
RoomEncryptedImage,
RoomEncryptedFile,
RoomEncryptedAudio,
RoomEncryptedVideo,
)
# Runtime callback filter for nio event dispatch (checked via isinstance).
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
# Static typing alias for media-specific handlers/helpers.
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
# Markdown renderer policy:
# https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes
@@ -345,7 +334,7 @@ class MatrixChannel(BaseChannel):
def _register_event_callbacks(self) -> None:
"""Register Matrix event callbacks used by this channel."""
self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_TYPES)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
self.client.add_event_callback(self._on_room_invite, InviteEvent)
def _register_response_callbacks(self) -> None:
@@ -460,7 +449,7 @@ class MatrixChannel(BaseChannel):
member_count = getattr(room, "member_count", None)
return isinstance(member_count, int) and member_count <= 2
def _is_bot_mentioned_from_mx_mentions(self, event: Any) -> bool:
def _is_bot_mentioned_from_mx_mentions(self, event: RoomMessage) -> bool:
"""Resolve mentions strictly from Matrix-native m.mentions payload."""
source = getattr(event, "source", None)
if not isinstance(source, dict):
@@ -480,7 +469,7 @@ class MatrixChannel(BaseChannel):
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
def _should_process_message(self, room: MatrixRoom, event: Any) -> bool:
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
"""Apply sender and room policy checks before processing Matrix messages."""
if not self.is_allowed(event.sender):
return False
@@ -505,7 +494,7 @@ class MatrixChannel(BaseChannel):
return media_dir
@staticmethod
def _event_source_content(event: Any) -> dict[str, Any]:
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
"""Extract Matrix event content payload when available."""
source = getattr(event, "source", None)
if not isinstance(source, dict):
@@ -513,7 +502,47 @@ class MatrixChannel(BaseChannel):
content = source.get("content")
return content if isinstance(content, dict) else {}
def _event_attachment_type(self, event: Any) -> str:
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
"""Return thread root event_id if this message is inside a thread."""
content = self._event_source_content(event)
relates_to = content.get("m.relates_to")
if not isinstance(relates_to, dict):
return None
if relates_to.get("rel_type") != "m.thread":
return None
root_id = relates_to.get("event_id")
return root_id if isinstance(root_id, str) and root_id else None
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
"""Build metadata used to reply within a thread."""
root_id = self._event_thread_root_id(event)
if not root_id:
return None
reply_to = getattr(event, "event_id", None)
meta: dict[str, str] = {"thread_root_event_id": root_id}
if isinstance(reply_to, str) and reply_to:
meta["thread_reply_to_event_id"] = reply_to
return meta
@staticmethod
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
"""Build m.relates_to payload for Matrix thread replies."""
if not metadata:
return None
root_id = metadata.get("thread_root_event_id")
if not isinstance(root_id, str) or not root_id:
return None
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
if not isinstance(reply_to, str) or not reply_to:
return None
return {
"rel_type": "m.thread",
"event_id": root_id,
"m.in_reply_to": {"event_id": reply_to},
"is_falling_back": True,
}
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
"""Map Matrix event payload/type to a stable attachment kind."""
msgtype = self._event_source_content(event).get("msgtype")
if msgtype == "m.image":
@@ -535,7 +564,7 @@ class MatrixChannel(BaseChannel):
return "file"
@staticmethod
def _is_encrypted_media_event(event: Any) -> bool:
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
"""Return True for encrypted Matrix media events."""
return (
isinstance(getattr(event, "key", None), dict)
@@ -543,7 +572,7 @@ class MatrixChannel(BaseChannel):
and isinstance(getattr(event, "iv", None), str)
)
def _event_declared_size_bytes(self, event: Any) -> int | None:
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
"""Return declared media size from Matrix event info, if present."""
info = self._event_source_content(event).get("info")
if not isinstance(info, dict):
@@ -553,7 +582,7 @@ class MatrixChannel(BaseChannel):
return size
return None
def _event_mime(self, event: Any) -> str | None:
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
"""Best-effort MIME extraction from Matrix media event."""
info = self._event_source_content(event).get("info")
if isinstance(info, dict):
@@ -566,7 +595,7 @@ class MatrixChannel(BaseChannel):
return mime
return None
def _event_filename(self, event: Any, attachment_type: str) -> str:
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
"""Build a safe filename for a Matrix attachment."""
body = getattr(event, "body", None)
if isinstance(body, str) and body.strip():
@@ -577,7 +606,7 @@ class MatrixChannel(BaseChannel):
def _build_attachment_path(
self,
event: Any,
event: MatrixMediaEvent,
attachment_type: str,
filename: str,
mime: str | None,
@@ -637,7 +666,7 @@ class MatrixChannel(BaseChannel):
)
return None
def _decrypt_media_bytes(self, event: Any, ciphertext: bytes) -> bytes | None:
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
"""Decrypt encrypted Matrix attachment bytes."""
key_obj = getattr(event, "key", None)
hashes = getattr(event, "hashes", None)
@@ -666,7 +695,7 @@ class MatrixChannel(BaseChannel):
async def _fetch_media_attachment(
self,
room: MatrixRoom,
event: Any,
event: MatrixMediaEvent,
) -> tuple[dict[str, Any] | None, str]:
"""Download and prepare a Matrix attachment for inbound processing."""
attachment_type = self._event_attachment_type(event)
@@ -683,10 +712,7 @@ class MatrixChannel(BaseChannel):
return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename)
declared_size = self._event_declared_size_bytes(event)
if (
declared_size is not None
and declared_size > self.config.max_inbound_media_bytes
):
if declared_size is not None and declared_size > self.config.max_inbound_media_bytes:
logger.warning(
"Matrix attachment skipped in room {}: declared size {} exceeds limit {}",
room.room_id,
@@ -765,7 +791,7 @@ class MatrixChannel(BaseChannel):
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
raise
async def _on_media_message(self, room: MatrixRoom, event: Any) -> None:
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
"""Handle inbound Matrix media events and forward local attachment paths."""
if event.sender == self.config.user_id:
return

View File

@@ -159,6 +159,23 @@ async def test_start_skips_load_store_when_device_id_missing(
await channel.stop()
@pytest.mark.asyncio
async def test_register_event_callbacks_uses_media_base_filter() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._register_event_callbacks()
assert len(client.callbacks) == 3
assert client.callbacks[1][0] == channel._on_media_message
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
def test_media_event_filter_does_not_match_text_events() -> None:
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
@pytest.mark.asyncio
async def test_start_disables_e2ee_when_configured(
monkeypatch, tmp_path