diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index 7129818..9e58f27 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -26,6 +26,7 @@ from nanobot.agent.i18n import (
 from nanobot.agent.memory import MemoryConsolidator
 from nanobot.agent.subagent import SubagentManager
 from nanobot.agent.tools.cron import CronTool
+from nanobot.agent.skills import BUILTIN_SKILLS_DIR
 from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
 from nanobot.agent.tools.message import MessageTool
 from nanobot.agent.tools.registry import ToolRegistry
@@ -177,7 +178,9 @@ class AgentLoop:
     def _register_default_tools(self) -> None:
         """Register the default set of tools."""
         allowed_dir = self.workspace if self.restrict_to_workspace else None
-        for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
+        extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
+        self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
+        for cls in (WriteFileTool, EditFileTool, ListDirTool):
             self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
         self.tools.register(ExecTool(
             working_dir=str(self.workspace),
diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py
index bee90a4..ed06138 100644
--- a/nanobot/agent/subagent.py
+++ b/nanobot/agent/subagent.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from loguru import logger
 
+from nanobot.agent.skills import BUILTIN_SKILLS_DIR
 from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
 from nanobot.agent.tools.registry import ToolRegistry
 from nanobot.agent.tools.shell import ExecTool
@@ -97,7 +98,8 @@ class SubagentManager:
             # Build subagent tools (no message tool, no spawn tool)
             tools = ToolRegistry()
             allowed_dir = self.workspace if self.restrict_to_workspace else None
-            tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
+            extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
+            tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
             tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
             tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
             tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py
index 02c8331..6443f28 100644
--- a/nanobot/agent/tools/filesystem.py
+++ b/nanobot/agent/tools/filesystem.py
@@ -8,7 +8,10 @@ from nanobot.agent.tools.base import Tool
 
 
 def _resolve_path(
-    path: str, workspace: Path | None = None, allowed_dir: Path | None = None
+    path: str,
+    workspace: Path | None = None,
+    allowed_dir: Path | None = None,
+    extra_allowed_dirs: list[Path] | None = None,
 ) -> Path:
     """Resolve path against workspace (if relative) and enforce directory restriction."""
     p = Path(path).expanduser()
@@ -16,22 +19,35 @@ def _resolve_path(
         p = workspace / p
     resolved = p.resolve()
     if allowed_dir:
-        try:
-            resolved.relative_to(allowed_dir.resolve())
-        except ValueError:
+        all_dirs = [allowed_dir] + (extra_allowed_dirs or [])
+        if not any(_is_under(resolved, d) for d in all_dirs):
             raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
     return resolved
 
 
+def _is_under(path: Path, directory: Path) -> bool:
+    try:
+        path.relative_to(directory.resolve())
+        return True
+    except ValueError:
+        return False
+
+
 class _FsTool(Tool):
     """Shared base for filesystem tools — common init and path resolution."""
 
-    def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
+    def __init__(
+        self,
+        workspace: Path | None = None,
+        allowed_dir: Path | None = None,
+        extra_allowed_dirs: list[Path] | None = None,
+    ):
         self._workspace = workspace
         self._allowed_dir = allowed_dir
+        self._extra_allowed_dirs = extra_allowed_dirs
 
     def _resolve(self, path: str) -> Path:
-        return _resolve_path(path, self._workspace, self._allowed_dir)
+        return _resolve_path(path, self._workspace, self._allowed_dir, self._extra_allowed_dirs)
 
 
 # ---------------------------------------------------------------------------
diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py
index 5a32155..d15896f 100644
--- a/nanobot/channels/dingtalk.py
+++ b/nanobot/channels/dingtalk.py
@@ -62,6 +62,49 @@ class NanobotDingTalkHandler(CallbackHandler):
             if not content:
                 content = message.data.get("text", {}).get("content", "").strip()
 
+            # Handle file/image messages
+            file_paths = []
+            if chatbot_msg.message_type == "picture" and chatbot_msg.image_content:
+                download_code = chatbot_msg.image_content.download_code
+                if download_code:
+                    sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
+                    fp = await self.channel._download_dingtalk_file(download_code, "image.jpg", sender_uid)
+                    if fp:
+                        file_paths.append(fp)
+                        content = content or "[Image]"
+
+            elif chatbot_msg.message_type == "file":
+                download_code = message.data.get("content", {}).get("downloadCode") or message.data.get("downloadCode")
+                fname = message.data.get("content", {}).get("fileName") or message.data.get("fileName") or "file"
+                if download_code:
+                    sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
+                    fp = await self.channel._download_dingtalk_file(download_code, fname, sender_uid)
+                    if fp:
+                        file_paths.append(fp)
+                        content = content or "[File]"
+
+            elif chatbot_msg.message_type == "richText" and chatbot_msg.rich_text_content:
+                rich_list = chatbot_msg.rich_text_content.rich_text_list or []
+                for item in rich_list:
+                    if not isinstance(item, dict):
+                        continue
+                    if item.get("type") == "text":
+                        t = item.get("text", "").strip()
+                        if t:
+                            content = (content + " " + t).strip() if content else t
+                    elif item.get("downloadCode"):
+                        dc = item["downloadCode"]
+                        fname = item.get("fileName") or "file"
+                        sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
+                        fp = await self.channel._download_dingtalk_file(dc, fname, sender_uid)
+                        if fp:
+                            file_paths.append(fp)
+                            content = content or "[File]"
+
+            if file_paths:
+                file_list = "\n".join("- " + p for p in file_paths)
+                content = content + "\n\nReceived files:\n" + file_list
+
             if not content:
                 logger.warning(
                     "Received empty or unsupported message type: {}",
@@ -472,3 +515,66 @@ class DingTalkChannel(BaseChannel):
             )
         except Exception as e:
             logger.error("Error publishing DingTalk message: {}", e)
+
+    async def _download_dingtalk_file(
+        self,
+        download_code: str,
+        filename: str,
+        sender_id: str,
+    ) -> str | None:
+        """Download a DingTalk file to the media directory, return local path.
+
+        ``filename`` and ``sender_id`` are taken from the inbound message, so
+        only their final path component is used for the destination path.
+        Returns ``None`` after logging when any step fails.
+        """
+        from pathlib import Path
+
+        from nanobot.config.paths import get_media_dir
+
+        try:
+            token = await self._get_access_token()
+            if not token or not self._http:
+                logger.error("DingTalk file download: no token or http client")
+                return None
+
+            # Step 1: Exchange downloadCode for a temporary download URL
+            api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
+            headers = {"x-acs-dingtalk-access-token": token, "Content-Type": "application/json"}
+            payload = {"downloadCode": download_code, "robotCode": self.config.client_id}
+            resp = await self._http.post(api_url, json=payload, headers=headers)
+            if resp.status_code != 200:
+                logger.error("DingTalk get download URL failed: status={}, body={}", resp.status_code, resp.text)
+                return None
+
+            result = resp.json()
+            download_url = result.get("downloadUrl")
+            if not download_url:
+                logger.error("DingTalk download URL not found in response: {}", result)
+                return None
+
+            # Step 2: Download the file content
+            file_resp = await self._http.get(download_url, follow_redirects=True)
+            if file_resp.status_code != 200:
+                logger.error("DingTalk file download failed: status={}", file_resp.status_code)
+                return None
+
+            # A crafted name such as "../../x" or "/etc/x" must not escape the
+            # media directory: keep only the last path component of each value.
+            safe_sender = Path(sender_id.replace("\\", "/")).name
+            if safe_sender in ("", ".", ".."):
+                safe_sender = "unknown"
+            safe_name = Path(filename.replace("\\", "/")).name
+            if safe_name in ("", ".", ".."):
+                safe_name = "file"
+
+            # Save to media directory (accessible under workspace)
+            download_dir = get_media_dir("dingtalk") / safe_sender
+            download_dir.mkdir(parents=True, exist_ok=True)
+            file_path = download_dir / safe_name
+            await asyncio.to_thread(file_path.write_bytes, file_resp.content)
+            logger.info("DingTalk file saved: {}", file_path)
+            return str(file_path)
+        except Exception as e:
+            logger.error("DingTalk file download error: {}", e)
+            return None
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 114a948..8b6956c 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -89,6 +89,14 @@ class LLMProvider(ABC):
         "server error",
         "temporarily unavailable",
     )
+    _IMAGE_UNSUPPORTED_MARKERS = (
+        "image_url is only supported",
+        "does not support image",
+        "images are not supported",
+        "image input is not supported",
+        "image_url is not supported",
+        "unsupported image input",
+    )
 
     _SENTINEL = object()
 
@@ -189,6 +197,40 @@ class LLMProvider(ABC):
         err = (content or "").lower()
         return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
 
+    @classmethod
+    def _is_image_unsupported_error(cls, content: str | None) -> bool:
+        err = (content or "").lower()
+        return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS)
+
+    @staticmethod
+    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
+        """Replace image_url blocks with text placeholder. Returns None if no images found."""
+        found = False
+        result = []
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                new_content = []
+                for b in content:
+                    if isinstance(b, dict) and b.get("type") == "image_url":
+                        new_content.append({"type": "text", "text": "[image omitted]"})
+                        found = True
+                    else:
+                        new_content.append(b)
+                result.append({**msg, "content": new_content})
+            else:
+                result.append(msg)
+        return result if found else None
+
+    async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
+        """Call chat() and convert unexpected exceptions to error responses."""
+        try:
+            return await self.chat(**kwargs)
+        except asyncio.CancelledError:
+            raise
+        except Exception as exc:
+            return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
+
     async def chat_with_retry(
         self,
         messages: list[dict[str, Any]],
@@ -212,57 +254,34 @@ class LLMProvider(ABC):
         if reasoning_effort is self._SENTINEL:
             reasoning_effort = self.generation.reasoning_effort
 
+        kw: dict[str, Any] = dict(
+            messages=messages, tools=tools, model=model,
+            max_tokens=max_tokens, temperature=temperature,
+            reasoning_effort=reasoning_effort, tool_choice=tool_choice,
+        )
+
         for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
-            try:
-                response = await self.chat(
-                    messages=messages,
-                    tools=tools,
-                    model=model,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    reasoning_effort=reasoning_effort,
-                    tool_choice=tool_choice,
-                )
-            except asyncio.CancelledError:
-                raise
-            except Exception as exc:
-                response = LLMResponse(
-                    content=f"Error calling LLM: {exc}",
-                    finish_reason="error",
-                )
+            response = await self._safe_chat(**kw)
 
             if response.finish_reason != "error":
                 return response
+
             if not self._is_transient_error(response.content):
+                if self._is_image_unsupported_error(response.content):
+                    stripped = self._strip_image_content(messages)
+                    if stripped is not None:
+                        logger.warning("Model does not support image input, retrying without images")
+                        return await self._safe_chat(**{**kw, "messages": stripped})
                 return response
 
-            err = (response.content or "").lower()
             logger.warning(
                 "LLM transient error (attempt {}/{}), retrying in {}s: {}",
-                attempt,
-                len(self._CHAT_RETRY_DELAYS),
-                delay,
-                err[:120],
+                attempt, len(self._CHAT_RETRY_DELAYS), delay,
+                (response.content or "")[:120].lower(),
             )
             await asyncio.sleep(delay)
 
-        try:
-            return await self.chat(
-                messages=messages,
-                tools=tools,
-                model=model,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                reasoning_effort=reasoning_effort,
-                tool_choice=tool_choice,
-            )
-        except asyncio.CancelledError:
-            raise
-        except Exception as exc:
-            return LLMResponse(
-                content=f"Error calling LLM: {exc}",
-                finish_reason="error",
-            )
+        return await self._safe_chat(**kw)
 
     @abstractmethod
     def get_default_model(self) -> str:
diff --git a/tests/test_dingtalk_channel.py b/tests/test_dingtalk_channel.py
index 6051014..1d36ee8 100644
--- a/tests/test_dingtalk_channel.py
+++ b/tests/test_dingtalk_channel.py
@@ -14,19 +14,31 @@ class _FakeResponse:
         self.status_code = status_code
         self._json_body = json_body or {}
         self.text = "{}"
+        self.content = b""
+        self.headers = {"content-type": "application/json"}
 
     def json(self) -> dict:
         return self._json_body
 
 
 class _FakeHttp:
-    def __init__(self) -> None:
+    def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
         self.calls: list[dict] = []
+        self._responses = list(responses) if responses else []
 
-    async def post(self, url: str, json=None, headers=None):
-        self.calls.append({"url": url, "json": json, "headers": headers})
+    def _next_response(self) -> _FakeResponse:
+        if self._responses:
+            return self._responses.pop(0)
         return _FakeResponse()
 
+    async def post(self, url: str, json=None, headers=None, **kwargs):
+        self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
+        return self._next_response()
+
+    async def get(self, url: str, **kwargs):
+        self.calls.append({"method": "GET", "url": url})
+        return self._next_response()
+
 
 @pytest.mark.asyncio
 async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
@@ -109,3 +121,93 @@ async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatc
     assert msg.content == "voice transcript"
     assert msg.sender_id == "user1"
     assert msg.chat_id == "group:conv123"
+
+
+@pytest.mark.asyncio
+async def test_handler_processes_file_message(monkeypatch) -> None:
+    """Test that file messages are handled and forwarded with downloaded path."""
+    bus = MessageBus()
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
+        bus,
+    )
+    handler = NanobotDingTalkHandler(channel)
+
+    class _FakeFileChatbotMessage:
+        text = None
+        extensions = {}
+        image_content = None
+        rich_text_content = None
+        sender_staff_id = "user1"
+        sender_id = "fallback-user"
+        sender_nick = "Alice"
+        message_type = "file"
+
+        @staticmethod
+        def from_dict(_data):
+            return _FakeFileChatbotMessage()
+
+    async def fake_download(download_code, filename, sender_id):
+        return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
+
+    monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
+    monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
+    monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
+
+    status, body = await handler.process(
+        SimpleNamespace(
+            data={
+                "conversationType": "1",
+                "content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
+                "text": {"content": ""},
+            }
+        )
+    )
+
+    await asyncio.gather(*list(channel._background_tasks))
+    msg = await bus.consume_inbound()
+
+    assert (status, body) == ("OK", "OK")
+    assert "[File]" in msg.content
+    assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
+
+
+@pytest.mark.asyncio
+async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
+    """Test the two-step file download flow (get URL then download content)."""
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+        MessageBus(),
+    )
+
+    # Mock access token
+    async def fake_get_token():
+        return "test-token"
+
+    monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
+
+    # Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
+    file_content = b"fake file content"
+    channel._http = _FakeHttp(responses=[
+        _FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
+        _FakeResponse(200),
+    ])
+    channel._http._responses[1].content = file_content
+
+    # Redirect media dir to tmp_path
+    monkeypatch.setattr(
+        "nanobot.config.paths.get_media_dir",
+        lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
+    )
+
+    result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
+
+    assert result is not None
+    assert result.endswith("test.xlsx")
+    assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
+
+    # Verify API calls
+    assert channel._http.calls[0]["method"] == "POST"
+    assert "messageFiles/download" in channel._http.calls[0]["url"]
+    assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
+    assert channel._http.calls[1]["method"] == "GET"
diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py
index db8f256..8c1a5e8 100644
--- a/tests/test_filesystem_tools.py
+++ b/tests/test_filesystem_tools.py
@@ -249,3 +249,114 @@ class TestListDirTool:
         result = await tool.execute(path=str(tmp_path / "nope"))
         assert "Error" in result
         assert "not found" in result
+
+
+# ---------------------------------------------------------------------------
+# Workspace restriction + extra_allowed_dirs
+# ---------------------------------------------------------------------------
+
+class TestWorkspaceRestriction:
+
+    @pytest.mark.asyncio
+    async def test_read_blocked_outside_workspace(self, tmp_path):
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        outside = tmp_path / "outside"
+        outside.mkdir()
+        secret = outside / "secret.txt"
+        secret.write_text("top secret")
+
+        tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
+        result = await tool.execute(path=str(secret))
+        assert "Error" in result
+        assert "outside" in result.lower()
+
+    @pytest.mark.asyncio
+    async def test_read_allowed_with_extra_dir(self, tmp_path):
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        skills_dir = tmp_path / "skills"
+        skills_dir.mkdir()
+        skill_file = skills_dir / "test_skill" / "SKILL.md"
+        skill_file.parent.mkdir()
+        skill_file.write_text("# Test Skill\nDo something.")
+
+        tool = ReadFileTool(
+            workspace=workspace, allowed_dir=workspace,
+            extra_allowed_dirs=[skills_dir],
+        )
+        result = await tool.execute(path=str(skill_file))
+        assert "Test Skill" in result
+        assert "Error" not in result
+
+    @pytest.mark.asyncio
+    async def test_extra_dirs_does_not_widen_write(self, tmp_path):
+        from nanobot.agent.tools.filesystem import WriteFileTool
+
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        outside = tmp_path / "outside"
+        outside.mkdir()
+
+        tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
+        result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
+        assert "Error" in result
+        assert "outside" in result.lower()
+
+    @pytest.mark.asyncio
+    async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        skills_dir = tmp_path / "skills"
+        skills_dir.mkdir()
+        unrelated = tmp_path / "other"
+        unrelated.mkdir()
+        secret = unrelated / "secret.txt"
+        secret.write_text("nope")
+
+        tool = ReadFileTool(
+            workspace=workspace, allowed_dir=workspace,
+            extra_allowed_dirs=[skills_dir],
+        )
+        result = await tool.execute(path=str(secret))
+        assert "Error" in result
+        assert "outside" in result.lower()
+
+    @pytest.mark.asyncio
+    async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
+        """Adding extra_allowed_dirs must not break normal workspace reads."""
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        ws_file = workspace / "README.md"
+        ws_file.write_text("hello from workspace")
+        skills_dir = tmp_path / "skills"
+        skills_dir.mkdir()
+
+        tool = ReadFileTool(
+            workspace=workspace, allowed_dir=workspace,
+            extra_allowed_dirs=[skills_dir],
+        )
+        result = await tool.execute(path=str(ws_file))
+        assert "hello from workspace" in result
+        assert "Error" not in result
+
+    @pytest.mark.asyncio
+    async def test_edit_blocked_in_extra_dir(self, tmp_path):
+        """edit_file must not be able to modify files in extra_allowed_dirs."""
+        workspace = tmp_path / "ws"
+        workspace.mkdir()
+        skills_dir = tmp_path / "skills"
+        skills_dir.mkdir()
+        skill_file = skills_dir / "weather" / "SKILL.md"
+        skill_file.parent.mkdir()
+        skill_file.write_text("# Weather\nOriginal content.")
+
+        tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
+        result = await tool.execute(
+            path=str(skill_file),
+            old_text="Original content.",
+            new_text="Hacked content.",
+        )
+        assert "Error" in result
+        assert "outside" in result.lower()
+        assert skill_file.read_text() == "# Weather\nOriginal content."
diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py
index 2420399..6f2c165 100644
--- a/tests/test_provider_retry.py
+++ b/tests/test_provider_retry.py
@@ -123,3 +123,87 @@ async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
     assert provider.last_kwargs["temperature"] == 0.9
     assert provider.last_kwargs["max_tokens"] == 9999
     assert provider.last_kwargs["reasoning_effort"] == "low"
+
+
+# ---------------------------------------------------------------------------
+# Image-unsupported fallback tests
+# ---------------------------------------------------------------------------
+
+_IMAGE_MSG = [
+    {"role": "user", "content": [
+        {"type": "text", "text": "describe this"},
+        {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
+    ]},
+]
+
+
+@pytest.mark.asyncio
+async def test_image_unsupported_error_retries_without_images() -> None:
+    """If the model rejects image_url, retry once with images stripped."""
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="Invalid content type. image_url is only supported by certain models",
+            finish_reason="error",
+        ),
+        LLMResponse(content="ok, no image"),
+    ])
+
+    response = await provider.chat_with_retry(messages=_IMAGE_MSG)
+
+    assert response.content == "ok, no image"
+    assert provider.calls == 2
+    msgs_on_retry = provider.last_kwargs["messages"]
+    for msg in msgs_on_retry:
+        content = msg.get("content")
+        if isinstance(content, list):
+            assert all(b.get("type") != "image_url" for b in content)
+            assert any("[image omitted]" in (b.get("text") or "") for b in content)
+
+
+@pytest.mark.asyncio
+async def test_image_unsupported_error_no_retry_without_image_content() -> None:
+    """If messages don't contain image_url blocks, don't retry on image error."""
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="image_url is only supported by certain models",
+            finish_reason="error",
+        ),
+    ])
+
+    response = await provider.chat_with_retry(
+        messages=[{"role": "user", "content": "hello"}],
+    )
+
+    assert provider.calls == 1
+    assert response.finish_reason == "error"
+
+
+@pytest.mark.asyncio
+async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None:
+    """If the image-stripped retry also fails, return that error."""
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="does not support image input",
+            finish_reason="error",
+        ),
+        LLMResponse(content="some other error", finish_reason="error"),
+    ])
+
+    response = await provider.chat_with_retry(messages=_IMAGE_MSG)
+
+    assert provider.calls == 2
+    assert response.content == "some other error"
+    assert response.finish_reason == "error"
+
+
+@pytest.mark.asyncio
+async def test_non_image_error_does_not_trigger_image_fallback() -> None:
+    """Regular non-transient errors must not trigger image stripping."""
+    provider = ScriptedProvider([
+        LLMResponse(content="401 unauthorized", finish_reason="error"),
+    ])
+
+    response = await provider.chat_with_retry(messages=_IMAGE_MSG)
+
+    assert provider.calls == 1
+    assert response.content == "401 unauthorized"