refactor(matrix): use base media event filter for callbacks
- Replaces the explicit media event tuple with MATRIX_MEDIA_EVENT_FILTER based on media base classes: (RoomMessageMedia, RoomEncryptedMedia). - Keeps MatrixMediaEvent as the static typing alias for media-specific handlers. - Removes MatrixInboundEvent and uses RoomMessage in mention-related logic. - Adds regression tests for: - callback registration using MATRIX_MEDIA_EVENT_FILTER - ensuring RoomMessageText is not matched by the media filter.
This commit is contained in:
committed by
Alexander Minges
parent
1103f000fc
commit
10de3bf329
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, TypeAlias
|
||||||
|
|
||||||
import nh3
|
import nh3
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -15,15 +15,10 @@ from nio import (
|
|||||||
JoinError,
|
JoinError,
|
||||||
MatrixRoom,
|
MatrixRoom,
|
||||||
MemoryDownloadResponse,
|
MemoryDownloadResponse,
|
||||||
RoomEncryptedAudio,
|
RoomEncryptedMedia,
|
||||||
RoomEncryptedFile,
|
RoomMessage,
|
||||||
RoomEncryptedImage,
|
RoomMessageMedia,
|
||||||
RoomEncryptedVideo,
|
|
||||||
RoomMessageAudio,
|
|
||||||
RoomMessageFile,
|
|
||||||
RoomMessageImage,
|
|
||||||
RoomMessageText,
|
RoomMessageText,
|
||||||
RoomMessageVideo,
|
|
||||||
RoomSendError,
|
RoomSendError,
|
||||||
RoomTypingError,
|
RoomTypingError,
|
||||||
SyncError,
|
SyncError,
|
||||||
@@ -51,16 +46,10 @@ MATRIX_ATTACHMENT_TOO_LARGE_TEMPLATE = "[attachment: {} - too large]"
|
|||||||
MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]"
|
MATRIX_ATTACHMENT_FAILED_TEMPLATE = "[attachment: {} - download failed]"
|
||||||
MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment"
|
MATRIX_DEFAULT_ATTACHMENT_NAME = "attachment"
|
||||||
|
|
||||||
MATRIX_MEDIA_EVENT_TYPES = (
|
# Runtime callback filter for nio event dispatch (checked via isinstance).
|
||||||
RoomMessageImage,
|
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||||
RoomMessageFile,
|
# Static typing alias for media-specific handlers/helpers.
|
||||||
RoomMessageAudio,
|
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||||
RoomMessageVideo,
|
|
||||||
RoomEncryptedImage,
|
|
||||||
RoomEncryptedFile,
|
|
||||||
RoomEncryptedAudio,
|
|
||||||
RoomEncryptedVideo,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Markdown renderer policy:
|
# Markdown renderer policy:
|
||||||
# https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes
|
# https://spec.matrix.org/v1.17/client-server-api/#mroommessage-msgtypes
|
||||||
@@ -345,7 +334,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
def _register_event_callbacks(self) -> None:
|
def _register_event_callbacks(self) -> None:
|
||||||
"""Register Matrix event callbacks used by this channel."""
|
"""Register Matrix event callbacks used by this channel."""
|
||||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_TYPES)
|
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||||
|
|
||||||
def _register_response_callbacks(self) -> None:
|
def _register_response_callbacks(self) -> None:
|
||||||
@@ -460,7 +449,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
member_count = getattr(room, "member_count", None)
|
member_count = getattr(room, "member_count", None)
|
||||||
return isinstance(member_count, int) and member_count <= 2
|
return isinstance(member_count, int) and member_count <= 2
|
||||||
|
|
||||||
def _is_bot_mentioned_from_mx_mentions(self, event: Any) -> bool:
|
def _is_bot_mentioned_from_mx_mentions(self, event: RoomMessage) -> bool:
|
||||||
"""Resolve mentions strictly from Matrix-native m.mentions payload."""
|
"""Resolve mentions strictly from Matrix-native m.mentions payload."""
|
||||||
source = getattr(event, "source", None)
|
source = getattr(event, "source", None)
|
||||||
if not isinstance(source, dict):
|
if not isinstance(source, dict):
|
||||||
@@ -480,7 +469,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
|
|
||||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||||
|
|
||||||
def _should_process_message(self, room: MatrixRoom, event: Any) -> bool:
|
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||||
"""Apply sender and room policy checks before processing Matrix messages."""
|
"""Apply sender and room policy checks before processing Matrix messages."""
|
||||||
if not self.is_allowed(event.sender):
|
if not self.is_allowed(event.sender):
|
||||||
return False
|
return False
|
||||||
@@ -505,7 +494,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return media_dir
|
return media_dir
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _event_source_content(event: Any) -> dict[str, Any]:
|
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||||
"""Extract Matrix event content payload when available."""
|
"""Extract Matrix event content payload when available."""
|
||||||
source = getattr(event, "source", None)
|
source = getattr(event, "source", None)
|
||||||
if not isinstance(source, dict):
|
if not isinstance(source, dict):
|
||||||
@@ -513,7 +502,47 @@ class MatrixChannel(BaseChannel):
|
|||||||
content = source.get("content")
|
content = source.get("content")
|
||||||
return content if isinstance(content, dict) else {}
|
return content if isinstance(content, dict) else {}
|
||||||
|
|
||||||
def _event_attachment_type(self, event: Any) -> str:
|
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||||
|
"""Return thread root event_id if this message is inside a thread."""
|
||||||
|
content = self._event_source_content(event)
|
||||||
|
relates_to = content.get("m.relates_to")
|
||||||
|
if not isinstance(relates_to, dict):
|
||||||
|
return None
|
||||||
|
if relates_to.get("rel_type") != "m.thread":
|
||||||
|
return None
|
||||||
|
root_id = relates_to.get("event_id")
|
||||||
|
return root_id if isinstance(root_id, str) and root_id else None
|
||||||
|
|
||||||
|
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||||
|
"""Build metadata used to reply within a thread."""
|
||||||
|
root_id = self._event_thread_root_id(event)
|
||||||
|
if not root_id:
|
||||||
|
return None
|
||||||
|
reply_to = getattr(event, "event_id", None)
|
||||||
|
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||||
|
if isinstance(reply_to, str) and reply_to:
|
||||||
|
meta["thread_reply_to_event_id"] = reply_to
|
||||||
|
return meta
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
|
"""Build m.relates_to payload for Matrix thread replies."""
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
root_id = metadata.get("thread_root_event_id")
|
||||||
|
if not isinstance(root_id, str) or not root_id:
|
||||||
|
return None
|
||||||
|
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||||
|
if not isinstance(reply_to, str) or not reply_to:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"rel_type": "m.thread",
|
||||||
|
"event_id": root_id,
|
||||||
|
"m.in_reply_to": {"event_id": reply_to},
|
||||||
|
"is_falling_back": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||||
"""Map Matrix event payload/type to a stable attachment kind."""
|
"""Map Matrix event payload/type to a stable attachment kind."""
|
||||||
msgtype = self._event_source_content(event).get("msgtype")
|
msgtype = self._event_source_content(event).get("msgtype")
|
||||||
if msgtype == "m.image":
|
if msgtype == "m.image":
|
||||||
@@ -535,7 +564,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return "file"
|
return "file"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_encrypted_media_event(event: Any) -> bool:
|
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||||
"""Return True for encrypted Matrix media events."""
|
"""Return True for encrypted Matrix media events."""
|
||||||
return (
|
return (
|
||||||
isinstance(getattr(event, "key", None), dict)
|
isinstance(getattr(event, "key", None), dict)
|
||||||
@@ -543,7 +572,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
and isinstance(getattr(event, "iv", None), str)
|
and isinstance(getattr(event, "iv", None), str)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _event_declared_size_bytes(self, event: Any) -> int | None:
|
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||||
"""Return declared media size from Matrix event info, if present."""
|
"""Return declared media size from Matrix event info, if present."""
|
||||||
info = self._event_source_content(event).get("info")
|
info = self._event_source_content(event).get("info")
|
||||||
if not isinstance(info, dict):
|
if not isinstance(info, dict):
|
||||||
@@ -553,7 +582,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return size
|
return size
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _event_mime(self, event: Any) -> str | None:
|
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||||
"""Best-effort MIME extraction from Matrix media event."""
|
"""Best-effort MIME extraction from Matrix media event."""
|
||||||
info = self._event_source_content(event).get("info")
|
info = self._event_source_content(event).get("info")
|
||||||
if isinstance(info, dict):
|
if isinstance(info, dict):
|
||||||
@@ -566,7 +595,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return mime
|
return mime
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _event_filename(self, event: Any, attachment_type: str) -> str:
|
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||||
"""Build a safe filename for a Matrix attachment."""
|
"""Build a safe filename for a Matrix attachment."""
|
||||||
body = getattr(event, "body", None)
|
body = getattr(event, "body", None)
|
||||||
if isinstance(body, str) and body.strip():
|
if isinstance(body, str) and body.strip():
|
||||||
@@ -577,7 +606,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
|
|
||||||
def _build_attachment_path(
|
def _build_attachment_path(
|
||||||
self,
|
self,
|
||||||
event: Any,
|
event: MatrixMediaEvent,
|
||||||
attachment_type: str,
|
attachment_type: str,
|
||||||
filename: str,
|
filename: str,
|
||||||
mime: str | None,
|
mime: str | None,
|
||||||
@@ -637,7 +666,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _decrypt_media_bytes(self, event: Any, ciphertext: bytes) -> bytes | None:
|
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||||
"""Decrypt encrypted Matrix attachment bytes."""
|
"""Decrypt encrypted Matrix attachment bytes."""
|
||||||
key_obj = getattr(event, "key", None)
|
key_obj = getattr(event, "key", None)
|
||||||
hashes = getattr(event, "hashes", None)
|
hashes = getattr(event, "hashes", None)
|
||||||
@@ -666,7 +695,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
async def _fetch_media_attachment(
|
async def _fetch_media_attachment(
|
||||||
self,
|
self,
|
||||||
room: MatrixRoom,
|
room: MatrixRoom,
|
||||||
event: Any,
|
event: MatrixMediaEvent,
|
||||||
) -> tuple[dict[str, Any] | None, str]:
|
) -> tuple[dict[str, Any] | None, str]:
|
||||||
"""Download and prepare a Matrix attachment for inbound processing."""
|
"""Download and prepare a Matrix attachment for inbound processing."""
|
||||||
attachment_type = self._event_attachment_type(event)
|
attachment_type = self._event_attachment_type(event)
|
||||||
@@ -683,10 +712,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename)
|
return None, MATRIX_ATTACHMENT_FAILED_TEMPLATE.format(filename)
|
||||||
|
|
||||||
declared_size = self._event_declared_size_bytes(event)
|
declared_size = self._event_declared_size_bytes(event)
|
||||||
if (
|
if declared_size is not None and declared_size > self.config.max_inbound_media_bytes:
|
||||||
declared_size is not None
|
|
||||||
and declared_size > self.config.max_inbound_media_bytes
|
|
||||||
):
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Matrix attachment skipped in room {}: declared size {} exceeds limit {}",
|
"Matrix attachment skipped in room {}: declared size {} exceeds limit {}",
|
||||||
room.room_id,
|
room.room_id,
|
||||||
@@ -765,7 +791,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _on_media_message(self, room: MatrixRoom, event: Any) -> None:
|
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||||
"""Handle inbound Matrix media events and forward local attachment paths."""
|
"""Handle inbound Matrix media events and forward local attachment paths."""
|
||||||
if event.sender == self.config.user_id:
|
if event.sender == self.config.user_id:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -159,6 +159,23 @@ async def test_start_skips_load_store_when_device_id_missing(
|
|||||||
await channel.stop()
|
await channel.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_event_callbacks_uses_media_base_filter() -> None:
|
||||||
|
channel = MatrixChannel(_make_config(), MessageBus())
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
channel._register_event_callbacks()
|
||||||
|
|
||||||
|
assert len(client.callbacks) == 3
|
||||||
|
assert client.callbacks[1][0] == channel._on_media_message
|
||||||
|
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
|
||||||
|
|
||||||
|
|
||||||
|
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||||
|
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_disables_e2ee_when_configured(
|
async def test_start_disables_e2ee_when_configured(
|
||||||
monkeypatch, tmp_path
|
monkeypatch, tmp_path
|
||||||
|
|||||||
Reference in New Issue
Block a user