refactor(matrix): use base media event filter for callbacks
- Replaces the explicit media event tuple with MATRIX_MEDIA_EVENT_FILTER based on media base classes: (RoomMessageMedia, RoomEncryptedMedia). - Keeps MatrixMediaEvent as the static typing alias for media-specific handlers. - Removes MatrixInboundEvent and uses RoomMessage in mention-related logic. - Adds regression tests for: - callback registration using MATRIX_MEDIA_EVENT_FILTER - ensuring RoomMessageText is not matched by the media filter.
This commit is contained in:
committed by
Alexander Minges
parent
1103f000fc
commit
10de3bf329
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import nh3
|
||||
from loguru import logger
|
||||
@@ -15,15 +15,10 @@ from nio import (
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedAudio,
|
||||
RoomEncryptedFile,
|
||||
RoomEncryptedImage,
|
||||
RoomEncryptedVideo,
|
||||
RoomMessageAudio,
|
||||
RoomMessageFile,
|
||||
RoomMessageImage,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomMessageVideo,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
@@ -51,16 +46,10 @@ MATRIX_ATTACHMENT_TOO_LARGE_TEMPLATE = "[attachment: {} - too large]"
|
||||
MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]"
|
||||
MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment"
|
||||
|
||||
MATRIX_MEDIA_EVENT_TYPES = (
|
||||
RoomMessageImage,
|
||||
RoomMessageFile,
|
||||
RoomMessageAudio,
|
||||
RoomMessageVideo,
|
||||
RoomEncryptedImage,
|
||||
RoomEncryptedFile,
|
||||
RoomEncryptedAudio,
|
||||
RoomEncryptedVideo,
|
||||
)
|
||||
# Runtime callback filter for nio event dispatch (checked via isinstance).
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
# Static typing alias for media-specific handlers/helpers.
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
# Markdown renderer policy:
|
||||
# https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes
|
||||
@@ -345,7 +334,7 @@ class MatrixChannel(BaseChannel):
|
||||
def _register_event_callbacks(self) -> None:
|
||||
"""Register Matrix event callbacks used by this channel."""
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_TYPES)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
@@ -460,7 +449,7 @@ class MatrixChannel(BaseChannel):
|
||||
member_count = getattr(room, "member_count", None)
|
||||
return isinstance(member_count, int) and member_count <= 2
|
||||
|
||||
def _is_bot_mentioned_from_mx_mentions(self, event: Any) -> bool:
|
||||
def _is_bot_mentioned_from_mx_mentions(self, event: RoomMessage) -> bool:
|
||||
"""Resolve mentions strictly from Matrix-native m.mentions payload."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
@@ -480,7 +469,7 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: Any) -> bool:
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks before processing Matrix messages."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
@@ -505,7 +494,7 @@ class MatrixChannel(BaseChannel):
|
||||
return media_dir
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: Any) -> dict[str, Any]:
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
"""Extract Matrix event content payload when available."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
@@ -513,7 +502,47 @@ class MatrixChannel(BaseChannel):
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_attachment_type(self, event: Any) -> str:
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
"""Return thread root event_id if this message is inside a thread."""
|
||||
content = self._event_source_content(event)
|
||||
relates_to = content.get("m.relates_to")
|
||||
if not isinstance(relates_to, dict):
|
||||
return None
|
||||
if relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
"""Build metadata used to reply within a thread."""
|
||||
root_id = self._event_thread_root_id(event)
|
||||
if not root_id:
|
||||
return None
|
||||
reply_to = getattr(event, "event_id", None)
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to, str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Build m.relates_to payload for Matrix thread replies."""
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
"""Map Matrix event payload/type to a stable attachment kind."""
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
if msgtype == "m.image":
|
||||
@@ -535,7 +564,7 @@ class MatrixChannel(BaseChannel):
|
||||
return "file"
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: Any) -> bool:
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
"""Return True for encrypted Matrix media events."""
|
||||
return (
|
||||
isinstance(getattr(event, "key", None), dict)
|
||||
@@ -543,7 +572,7 @@ class MatrixChannel(BaseChannel):
|
||||
and isinstance(getattr(event, "iv", None), str)
|
||||
)
|
||||
|
||||
def _event_declared_size_bytes(self, event: Any) -> int | None:
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
"""Return declared media size from Matrix event info, if present."""
|
||||
info = self._event_source_content(event).get("info")
|
||||
if not isinstance(info, dict):
|
||||
@@ -553,7 +582,7 @@ class MatrixChannel(BaseChannel):
|
||||
return size
|
||||
return None
|
||||
|
||||
def _event_mime(self, event: Any) -> str | None:
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
"""Best-effort MIME extraction from Matrix media event."""
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict):
|
||||
@@ -566,7 +595,7 @@ class MatrixChannel(BaseChannel):
|
||||
return mime
|
||||
return None
|
||||
|
||||
def _event_filename(self, event: Any, attachment_type: str) -> str:
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
"""Build a safe filename for a Matrix attachment."""
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
@@ -577,7 +606,7 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
def _build_attachment_path(
|
||||
self,
|
||||
event: Any,
|
||||
event: MatrixMediaEvent,
|
||||
attachment_type: str,
|
||||
filename: str,
|
||||
mime: str | None,
|
||||
@@ -637,7 +666,7 @@ class MatrixChannel(BaseChannel):
|
||||
)
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: Any, ciphertext: bytes) -> bytes | None:
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
"""Decrypt encrypted Matrix attachment bytes."""
|
||||
key_obj = getattr(event, "key", None)
|
||||
hashes = getattr(event, "hashes", None)
|
||||
@@ -666,7 +695,7 @@ class MatrixChannel(BaseChannel):
|
||||
async def _fetch_media_attachment(
|
||||
self,
|
||||
room: MatrixRoom,
|
||||
event: Any,
|
||||
event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download and prepare a Matrix attachment for inbound processing."""
|
||||
attachment_type = self._event_attachment_type(event)
|
||||
@@ -683,10 +712,7 @@ class MatrixChannel(BaseChannel):
|
||||
return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename)
|
||||
|
||||
declared_size = self._event_declared_size_bytes(event)
|
||||
if (
|
||||
declared_size is not None
|
||||
and declared_size > self.config.max_inbound_media_bytes
|
||||
):
|
||||
if declared_size is not None and declared_size > self.config.max_inbound_media_bytes:
|
||||
logger.warning(
|
||||
"Matrix attachment skipped in room {}: declared size {} exceeds limit {}",
|
||||
room.room_id,
|
||||
@@ -765,7 +791,7 @@ class MatrixChannel(BaseChannel):
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: Any) -> None:
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
"""Handle inbound Matrix media events and forward local attachment paths."""
|
||||
if event.sender == self.config.user_id:
|
||||
return
|
||||
|
||||
@@ -159,6 +159,23 @@ async def test_start_skips_load_store_when_device_id_missing(
|
||||
await channel.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_event_callbacks_uses_media_base_filter() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
channel._register_event_callbacks()
|
||||
|
||||
assert len(client.callbacks) == 3
|
||||
assert client.callbacks[1][0] == channel._on_media_message
|
||||
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
|
||||
|
||||
|
||||
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_disables_e2ee_when_configured(
|
||||
monkeypatch, tmp_path
|
||||
|
||||
Reference in New Issue
Block a user