style: unify code formatting and import order
- Remove trailing whitespace and normalize blank lines - Unify string quotes and line breaks for long lines - Sort imports alphabetically across modules
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Agent core module."""
|
"""Agent core module."""
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
|
|
||||||
|
|||||||
@@ -14,15 +14,15 @@ from nanobot.agent.skills import SkillsLoader
|
|||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
"""Builds the context (system prompt + messages) for the agent."""
|
"""Builds the context (system prompt + messages) for the agent."""
|
||||||
|
|
||||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
||||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.memory = MemoryStore(workspace)
|
self.memory = MemoryStore(workspace)
|
||||||
self.skills = SkillsLoader(workspace)
|
self.skills = SkillsLoader(workspace)
|
||||||
|
|
||||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||||
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||||
parts = [self._get_identity()]
|
parts = [self._get_identity()]
|
||||||
@@ -51,13 +51,13 @@ Skills with available="false" need dependencies installed first - you can try in
|
|||||||
{skills_summary}""")
|
{skills_summary}""")
|
||||||
|
|
||||||
return "\n\n---\n\n".join(parts)
|
return "\n\n---\n\n".join(parts)
|
||||||
|
|
||||||
def _get_identity(self) -> str:
|
def _get_identity(self) -> str:
|
||||||
"""Get the core identity section."""
|
"""Get the core identity section."""
|
||||||
workspace_path = str(self.workspace.expanduser().resolve())
|
workspace_path = str(self.workspace.expanduser().resolve())
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||||
|
|
||||||
return f"""# nanobot 🐈
|
return f"""# nanobot 🐈
|
||||||
|
|
||||||
You are nanobot, a helpful AI assistant.
|
You are nanobot, a helpful AI assistant.
|
||||||
@@ -89,19 +89,19 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
if channel and chat_id:
|
if channel and chat_id:
|
||||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
for filename in self.BOOTSTRAP_FILES:
|
for filename in self.BOOTSTRAP_FILES:
|
||||||
file_path = self.workspace / filename
|
file_path = self.workspace / filename
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
content = file_path.read_text(encoding="utf-8")
|
content = file_path.read_text(encoding="utf-8")
|
||||||
parts.append(f"## {filename}\n\n{content}")
|
parts.append(f"## {filename}\n\n{content}")
|
||||||
|
|
||||||
return "\n\n".join(parts) if parts else ""
|
return "\n\n".join(parts) if parts else ""
|
||||||
|
|
||||||
def build_messages(
|
def build_messages(
|
||||||
self,
|
self,
|
||||||
history: list[dict[str, Any]],
|
history: list[dict[str, Any]],
|
||||||
@@ -123,7 +123,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
"""Build user message content with optional base64-encoded images."""
|
"""Build user message content with optional base64-encoded images."""
|
||||||
if not media:
|
if not media:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
for path in media:
|
for path in media:
|
||||||
p = Path(path)
|
p = Path(path)
|
||||||
@@ -132,11 +132,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
continue
|
continue
|
||||||
b64 = base64.b64encode(p.read_bytes()).decode()
|
b64 = base64.b64encode(p.read_bytes()).decode()
|
||||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||||
|
|
||||||
if not images:
|
if not images:
|
||||||
return text
|
return text
|
||||||
return images + [{"type": "text", "text": text}]
|
return images + [{"type": "text", "text": text}]
|
||||||
|
|
||||||
def add_tool_result(
|
def add_tool_result(
|
||||||
self, messages: list[dict[str, Any]],
|
self, messages: list[dict[str, Any]],
|
||||||
tool_call_id: str, tool_name: str, result: str,
|
tool_call_id: str, tool_name: str, result: str,
|
||||||
@@ -144,7 +144,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
"""Add a tool result to the message list."""
|
"""Add a tool result to the message list."""
|
||||||
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_assistant_message(
|
def add_assistant_message(
|
||||||
self, messages: list[dict[str, Any]],
|
self, messages: list[dict[str, Any]],
|
||||||
content: str | None,
|
content: str | None,
|
||||||
|
|||||||
@@ -13,28 +13,28 @@ BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
|
|||||||
class SkillsLoader:
|
class SkillsLoader:
|
||||||
"""
|
"""
|
||||||
Loader for agent skills.
|
Loader for agent skills.
|
||||||
|
|
||||||
Skills are markdown files (SKILL.md) that teach the agent how to use
|
Skills are markdown files (SKILL.md) that teach the agent how to use
|
||||||
specific tools or perform certain tasks.
|
specific tools or perform certain tasks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.workspace_skills = workspace / "skills"
|
self.workspace_skills = workspace / "skills"
|
||||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||||
|
|
||||||
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
List all available skills.
|
List all available skills.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filter_unavailable: If True, filter out skills with unmet requirements.
|
filter_unavailable: If True, filter out skills with unmet requirements.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of skill info dicts with 'name', 'path', 'source'.
|
List of skill info dicts with 'name', 'path', 'source'.
|
||||||
"""
|
"""
|
||||||
skills = []
|
skills = []
|
||||||
|
|
||||||
# Workspace skills (highest priority)
|
# Workspace skills (highest priority)
|
||||||
if self.workspace_skills.exists():
|
if self.workspace_skills.exists():
|
||||||
for skill_dir in self.workspace_skills.iterdir():
|
for skill_dir in self.workspace_skills.iterdir():
|
||||||
@@ -42,7 +42,7 @@ class SkillsLoader:
|
|||||||
skill_file = skill_dir / "SKILL.md"
|
skill_file = skill_dir / "SKILL.md"
|
||||||
if skill_file.exists():
|
if skill_file.exists():
|
||||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
||||||
|
|
||||||
# Built-in skills
|
# Built-in skills
|
||||||
if self.builtin_skills and self.builtin_skills.exists():
|
if self.builtin_skills and self.builtin_skills.exists():
|
||||||
for skill_dir in self.builtin_skills.iterdir():
|
for skill_dir in self.builtin_skills.iterdir():
|
||||||
@@ -50,19 +50,19 @@ class SkillsLoader:
|
|||||||
skill_file = skill_dir / "SKILL.md"
|
skill_file = skill_dir / "SKILL.md"
|
||||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
||||||
|
|
||||||
# Filter by requirements
|
# Filter by requirements
|
||||||
if filter_unavailable:
|
if filter_unavailable:
|
||||||
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
||||||
return skills
|
return skills
|
||||||
|
|
||||||
def load_skill(self, name: str) -> str | None:
|
def load_skill(self, name: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Load a skill by name.
|
Load a skill by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Skill name (directory name).
|
name: Skill name (directory name).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Skill content or None if not found.
|
Skill content or None if not found.
|
||||||
"""
|
"""
|
||||||
@@ -70,22 +70,22 @@ class SkillsLoader:
|
|||||||
workspace_skill = self.workspace_skills / name / "SKILL.md"
|
workspace_skill = self.workspace_skills / name / "SKILL.md"
|
||||||
if workspace_skill.exists():
|
if workspace_skill.exists():
|
||||||
return workspace_skill.read_text(encoding="utf-8")
|
return workspace_skill.read_text(encoding="utf-8")
|
||||||
|
|
||||||
# Check built-in
|
# Check built-in
|
||||||
if self.builtin_skills:
|
if self.builtin_skills:
|
||||||
builtin_skill = self.builtin_skills / name / "SKILL.md"
|
builtin_skill = self.builtin_skills / name / "SKILL.md"
|
||||||
if builtin_skill.exists():
|
if builtin_skill.exists():
|
||||||
return builtin_skill.read_text(encoding="utf-8")
|
return builtin_skill.read_text(encoding="utf-8")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Load specific skills for inclusion in agent context.
|
Load specific skills for inclusion in agent context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
skill_names: List of skill names to load.
|
skill_names: List of skill names to load.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted skills content.
|
Formatted skills content.
|
||||||
"""
|
"""
|
||||||
@@ -95,26 +95,26 @@ class SkillsLoader:
|
|||||||
if content:
|
if content:
|
||||||
content = self._strip_frontmatter(content)
|
content = self._strip_frontmatter(content)
|
||||||
parts.append(f"### Skill: {name}\n\n{content}")
|
parts.append(f"### Skill: {name}\n\n{content}")
|
||||||
|
|
||||||
return "\n\n---\n\n".join(parts) if parts else ""
|
return "\n\n---\n\n".join(parts) if parts else ""
|
||||||
|
|
||||||
def build_skills_summary(self) -> str:
|
def build_skills_summary(self) -> str:
|
||||||
"""
|
"""
|
||||||
Build a summary of all skills (name, description, path, availability).
|
Build a summary of all skills (name, description, path, availability).
|
||||||
|
|
||||||
This is used for progressive loading - the agent can read the full
|
This is used for progressive loading - the agent can read the full
|
||||||
skill content using read_file when needed.
|
skill content using read_file when needed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
XML-formatted skills summary.
|
XML-formatted skills summary.
|
||||||
"""
|
"""
|
||||||
all_skills = self.list_skills(filter_unavailable=False)
|
all_skills = self.list_skills(filter_unavailable=False)
|
||||||
if not all_skills:
|
if not all_skills:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def escape_xml(s: str) -> str:
|
def escape_xml(s: str) -> str:
|
||||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
|
|
||||||
lines = ["<skills>"]
|
lines = ["<skills>"]
|
||||||
for s in all_skills:
|
for s in all_skills:
|
||||||
name = escape_xml(s["name"])
|
name = escape_xml(s["name"])
|
||||||
@@ -122,23 +122,23 @@ class SkillsLoader:
|
|||||||
desc = escape_xml(self._get_skill_description(s["name"]))
|
desc = escape_xml(self._get_skill_description(s["name"]))
|
||||||
skill_meta = self._get_skill_meta(s["name"])
|
skill_meta = self._get_skill_meta(s["name"])
|
||||||
available = self._check_requirements(skill_meta)
|
available = self._check_requirements(skill_meta)
|
||||||
|
|
||||||
lines.append(f" <skill available=\"{str(available).lower()}\">")
|
lines.append(f" <skill available=\"{str(available).lower()}\">")
|
||||||
lines.append(f" <name>{name}</name>")
|
lines.append(f" <name>{name}</name>")
|
||||||
lines.append(f" <description>{desc}</description>")
|
lines.append(f" <description>{desc}</description>")
|
||||||
lines.append(f" <location>{path}</location>")
|
lines.append(f" <location>{path}</location>")
|
||||||
|
|
||||||
# Show missing requirements for unavailable skills
|
# Show missing requirements for unavailable skills
|
||||||
if not available:
|
if not available:
|
||||||
missing = self._get_missing_requirements(skill_meta)
|
missing = self._get_missing_requirements(skill_meta)
|
||||||
if missing:
|
if missing:
|
||||||
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
||||||
|
|
||||||
lines.append(f" </skill>")
|
lines.append(" </skill>")
|
||||||
lines.append("</skills>")
|
lines.append("</skills>")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
||||||
"""Get a description of missing requirements."""
|
"""Get a description of missing requirements."""
|
||||||
missing = []
|
missing = []
|
||||||
@@ -150,14 +150,14 @@ class SkillsLoader:
|
|||||||
if not os.environ.get(env):
|
if not os.environ.get(env):
|
||||||
missing.append(f"ENV: {env}")
|
missing.append(f"ENV: {env}")
|
||||||
return ", ".join(missing)
|
return ", ".join(missing)
|
||||||
|
|
||||||
def _get_skill_description(self, name: str) -> str:
|
def _get_skill_description(self, name: str) -> str:
|
||||||
"""Get the description of a skill from its frontmatter."""
|
"""Get the description of a skill from its frontmatter."""
|
||||||
meta = self.get_skill_metadata(name)
|
meta = self.get_skill_metadata(name)
|
||||||
if meta and meta.get("description"):
|
if meta and meta.get("description"):
|
||||||
return meta["description"]
|
return meta["description"]
|
||||||
return name # Fallback to skill name
|
return name # Fallback to skill name
|
||||||
|
|
||||||
def _strip_frontmatter(self, content: str) -> str:
|
def _strip_frontmatter(self, content: str) -> str:
|
||||||
"""Remove YAML frontmatter from markdown content."""
|
"""Remove YAML frontmatter from markdown content."""
|
||||||
if content.startswith("---"):
|
if content.startswith("---"):
|
||||||
@@ -165,7 +165,7 @@ class SkillsLoader:
|
|||||||
if match:
|
if match:
|
||||||
return content[match.end():].strip()
|
return content[match.end():].strip()
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
||||||
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
||||||
try:
|
try:
|
||||||
@@ -173,7 +173,7 @@ class SkillsLoader:
|
|||||||
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
|
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _check_requirements(self, skill_meta: dict) -> bool:
|
def _check_requirements(self, skill_meta: dict) -> bool:
|
||||||
"""Check if skill requirements are met (bins, env vars)."""
|
"""Check if skill requirements are met (bins, env vars)."""
|
||||||
requires = skill_meta.get("requires", {})
|
requires = skill_meta.get("requires", {})
|
||||||
@@ -184,12 +184,12 @@ class SkillsLoader:
|
|||||||
if not os.environ.get(env):
|
if not os.environ.get(env):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_skill_meta(self, name: str) -> dict:
|
def _get_skill_meta(self, name: str) -> dict:
|
||||||
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
||||||
meta = self.get_skill_metadata(name) or {}
|
meta = self.get_skill_metadata(name) or {}
|
||||||
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||||
|
|
||||||
def get_always_skills(self) -> list[str]:
|
def get_always_skills(self) -> list[str]:
|
||||||
"""Get skills marked as always=true that meet requirements."""
|
"""Get skills marked as always=true that meet requirements."""
|
||||||
result = []
|
result = []
|
||||||
@@ -199,21 +199,21 @@ class SkillsLoader:
|
|||||||
if skill_meta.get("always") or meta.get("always"):
|
if skill_meta.get("always") or meta.get("always"):
|
||||||
result.append(s["name"])
|
result.append(s["name"])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_skill_metadata(self, name: str) -> dict | None:
|
def get_skill_metadata(self, name: str) -> dict | None:
|
||||||
"""
|
"""
|
||||||
Get metadata from a skill's frontmatter.
|
Get metadata from a skill's frontmatter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Skill name.
|
name: Skill name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Metadata dict or None.
|
Metadata dict or None.
|
||||||
"""
|
"""
|
||||||
content = self.load_skill(name)
|
content = self.load_skill(name)
|
||||||
if not content:
|
if not content:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if content.startswith("---"):
|
if content.startswith("---"):
|
||||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
@@ -224,5 +224,5 @@ class SkillsLoader:
|
|||||||
key, value = line.split(":", 1)
|
key, value = line.split(":", 1)
|
||||||
metadata[key.strip()] = value.strip().strip('"\'')
|
metadata[key.strip()] = value.strip().strip('"\'')
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -8,18 +8,19 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
|
||||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
|
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
|
||||||
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
|
||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
"""Manages background subagent execution."""
|
"""Manages background subagent execution."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
@@ -44,7 +45,7 @@ class SubagentManager:
|
|||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
@@ -73,10 +74,10 @@ class SubagentManager:
|
|||||||
del self._session_tasks[session_key]
|
del self._session_tasks[session_key]
|
||||||
|
|
||||||
bg_task.add_done_callback(_cleanup)
|
bg_task.add_done_callback(_cleanup)
|
||||||
|
|
||||||
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||||
|
|
||||||
async def _run_subagent(
|
async def _run_subagent(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -86,7 +87,7 @@ class SubagentManager:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the subagent task and announce the result."""
|
"""Execute the subagent task and announce the result."""
|
||||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build subagent tools (no message tool, no spawn tool)
|
# Build subagent tools (no message tool, no spawn tool)
|
||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
@@ -103,22 +104,22 @@ class SubagentManager:
|
|||||||
))
|
))
|
||||||
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
tools.register(WebFetchTool())
|
tools.register(WebFetchTool())
|
||||||
|
|
||||||
# Build messages with subagent-specific prompt
|
# Build messages with subagent-specific prompt
|
||||||
system_prompt = self._build_subagent_prompt(task)
|
system_prompt = self._build_subagent_prompt(task)
|
||||||
messages: list[dict[str, Any]] = [
|
messages: list[dict[str, Any]] = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": task},
|
{"role": "user", "content": task},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run agent loop (limited iterations)
|
# Run agent loop (limited iterations)
|
||||||
max_iterations = 15
|
max_iterations = 15
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_result: str | None = None
|
final_result: str | None = None
|
||||||
|
|
||||||
while iteration < max_iterations:
|
while iteration < max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
@@ -126,7 +127,7 @@ class SubagentManager:
|
|||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
# Add assistant message with tool calls
|
# Add assistant message with tool calls
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
@@ -145,7 +146,7 @@ class SubagentManager:
|
|||||||
"content": response.content or "",
|
"content": response.content or "",
|
||||||
"tool_calls": tool_call_dicts,
|
"tool_calls": tool_call_dicts,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||||
@@ -160,18 +161,18 @@ class SubagentManager:
|
|||||||
else:
|
else:
|
||||||
final_result = response.content
|
final_result = response.content
|
||||||
break
|
break
|
||||||
|
|
||||||
if final_result is None:
|
if final_result is None:
|
||||||
final_result = "Task completed but no final response was generated."
|
final_result = "Task completed but no final response was generated."
|
||||||
|
|
||||||
logger.info("Subagent [{}] completed successfully", task_id)
|
logger.info("Subagent [{}] completed successfully", task_id)
|
||||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error: {str(e)}"
|
error_msg = f"Error: {str(e)}"
|
||||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||||
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
||||||
|
|
||||||
async def _announce_result(
|
async def _announce_result(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -183,7 +184,7 @@ class SubagentManager:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Announce the subagent result to the main agent via the message bus."""
|
"""Announce the subagent result to the main agent via the message bus."""
|
||||||
status_text = "completed successfully" if status == "ok" else "failed"
|
status_text = "completed successfully" if status == "ok" else "failed"
|
||||||
|
|
||||||
announce_content = f"""[Subagent '{label}' {status_text}]
|
announce_content = f"""[Subagent '{label}' {status_text}]
|
||||||
|
|
||||||
Task: {task}
|
Task: {task}
|
||||||
@@ -192,7 +193,7 @@ Result:
|
|||||||
{result}
|
{result}
|
||||||
|
|
||||||
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
|
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
|
||||||
|
|
||||||
# Inject as system message to trigger main agent
|
# Inject as system message to trigger main agent
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel="system",
|
channel="system",
|
||||||
@@ -200,14 +201,14 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
|||||||
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
||||||
content=announce_content,
|
content=announce_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||||
|
|
||||||
def _build_subagent_prompt(self, task: str) -> str:
|
def _build_subagent_prompt(self, task: str) -> str:
|
||||||
"""Build a focused system prompt for the subagent."""
|
"""Build a focused system prompt for the subagent."""
|
||||||
from datetime import datetime
|
|
||||||
import time as _time
|
import time as _time
|
||||||
|
from datetime import datetime
|
||||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||||
tz = _time.strftime("%Z") or "UTC"
|
tz = _time.strftime("%Z") or "UTC"
|
||||||
|
|
||||||
@@ -240,7 +241,7 @@ Your workspace is at: {self.workspace}
|
|||||||
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
||||||
|
|
||||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
When you have completed the task, provide a clear summary of your findings or actions."""
|
||||||
|
|
||||||
async def cancel_by_session(self, session_key: str) -> int:
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
"""Cancel all subagents for the given session. Returns count cancelled."""
|
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||||
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ from typing import Any
|
|||||||
class Tool(ABC):
|
class Tool(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for agent tools.
|
Abstract base class for agent tools.
|
||||||
|
|
||||||
Tools are capabilities that the agent can use to interact with
|
Tools are capabilities that the agent can use to interact with
|
||||||
the environment, such as reading files, executing commands, etc.
|
the environment, such as reading files, executing commands, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TYPE_MAP = {
|
_TYPE_MAP = {
|
||||||
"string": str,
|
"string": str,
|
||||||
"integer": int,
|
"integer": int,
|
||||||
@@ -20,33 +20,33 @@ class Tool(ABC):
|
|||||||
"array": list,
|
"array": list,
|
||||||
"object": dict,
|
"object": dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Tool name used in function calls."""
|
"""Tool name used in function calls."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
"""Description of what the tool does."""
|
"""Description of what the tool does."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
"""JSON Schema for tool parameters."""
|
"""JSON Schema for tool parameters."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, **kwargs: Any) -> str:
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Execute the tool with given parameters.
|
Execute the tool with given parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Tool-specific parameters.
|
**kwargs: Tool-specific parameters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
String result of the tool execution.
|
String result of the tool execution.
|
||||||
"""
|
"""
|
||||||
@@ -63,7 +63,7 @@ class Tool(ABC):
|
|||||||
t, label = schema.get("type"), path or "parameter"
|
t, label = schema.get("type"), path or "parameter"
|
||||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
||||||
return [f"{label} should be {t}"]
|
return [f"{label} should be {t}"]
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
if "enum" in schema and val not in schema["enum"]:
|
if "enum" in schema and val not in schema["enum"]:
|
||||||
errors.append(f"{label} must be one of {schema['enum']}")
|
errors.append(f"{label} must be one of {schema['enum']}")
|
||||||
@@ -84,12 +84,14 @@ class Tool(ABC):
|
|||||||
errors.append(f"missing required {path + '.' + k if path else k}")
|
errors.append(f"missing required {path + '.' + k if path else k}")
|
||||||
for k, v in val.items():
|
for k, v in val.items():
|
||||||
if k in props:
|
if k in props:
|
||||||
errors.extend(self._validate(v, props[k], path + '.' + k if path else k))
|
errors.extend(self._validate(v, props[k], path + "." + k if path else k))
|
||||||
if t == "array" and "items" in schema:
|
if t == "array" and "items" in schema:
|
||||||
for i, item in enumerate(val):
|
for i, item in enumerate(val):
|
||||||
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]"))
|
errors.extend(
|
||||||
|
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
|
||||||
|
)
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
def to_schema(self) -> dict[str, Any]:
|
def to_schema(self) -> dict[str, Any]:
|
||||||
"""Convert tool to OpenAI function schema format."""
|
"""Convert tool to OpenAI function schema format."""
|
||||||
return {
|
return {
|
||||||
@@ -98,5 +100,5 @@ class Tool(ABC):
|
|||||||
"name": self.name,
|
"name": self.name,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
"parameters": self.parameters,
|
"parameters": self.parameters,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,25 +9,25 @@ from nanobot.cron.types import CronSchedule
|
|||||||
|
|
||||||
class CronTool(Tool):
|
class CronTool(Tool):
|
||||||
"""Tool to schedule reminders and recurring tasks."""
|
"""Tool to schedule reminders and recurring tasks."""
|
||||||
|
|
||||||
def __init__(self, cron_service: CronService):
|
def __init__(self, cron_service: CronService):
|
||||||
self._cron = cron_service
|
self._cron = cron_service
|
||||||
self._channel = ""
|
self._channel = ""
|
||||||
self._chat_id = ""
|
self._chat_id = ""
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
def set_context(self, channel: str, chat_id: str) -> None:
|
||||||
"""Set the current session context for delivery."""
|
"""Set the current session context for delivery."""
|
||||||
self._channel = channel
|
self._channel = channel
|
||||||
self._chat_id = chat_id
|
self._chat_id = chat_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "cron"
|
return "cron"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Schedule reminders and recurring tasks. Actions: add, list, remove."
|
return "Schedule reminders and recurring tasks. Actions: add, list, remove."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@@ -36,36 +36,30 @@ class CronTool(Tool):
|
|||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["add", "list", "remove"],
|
"enum": ["add", "list", "remove"],
|
||||||
"description": "Action to perform"
|
"description": "Action to perform",
|
||||||
},
|
|
||||||
"message": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Reminder message (for add)"
|
|
||||||
},
|
},
|
||||||
|
"message": {"type": "string", "description": "Reminder message (for add)"},
|
||||||
"every_seconds": {
|
"every_seconds": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Interval in seconds (for recurring tasks)"
|
"description": "Interval in seconds (for recurring tasks)",
|
||||||
},
|
},
|
||||||
"cron_expr": {
|
"cron_expr": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
|
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
|
||||||
},
|
},
|
||||||
"tz": {
|
"tz": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')"
|
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
|
||||||
},
|
},
|
||||||
"at": {
|
"at": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
|
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
|
||||||
},
|
},
|
||||||
"job_id": {
|
"job_id": {"type": "string", "description": "Job ID (for remove)"},
|
||||||
"type": "string",
|
|
||||||
"description": "Job ID (for remove)"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["action"]
|
"required": ["action"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
action: str,
|
action: str,
|
||||||
@@ -75,7 +69,7 @@ class CronTool(Tool):
|
|||||||
tz: str | None = None,
|
tz: str | None = None,
|
||||||
at: str | None = None,
|
at: str | None = None,
|
||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
**kwargs: Any
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
if action == "add":
|
if action == "add":
|
||||||
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
||||||
@@ -84,7 +78,7 @@ class CronTool(Tool):
|
|||||||
elif action == "remove":
|
elif action == "remove":
|
||||||
return self._remove_job(job_id)
|
return self._remove_job(job_id)
|
||||||
return f"Unknown action: {action}"
|
return f"Unknown action: {action}"
|
||||||
|
|
||||||
def _add_job(
|
def _add_job(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -101,11 +95,12 @@ class CronTool(Tool):
|
|||||||
return "Error: tz can only be used with cron_expr"
|
return "Error: tz can only be used with cron_expr"
|
||||||
if tz:
|
if tz:
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ZoneInfo(tz)
|
ZoneInfo(tz)
|
||||||
except (KeyError, Exception):
|
except (KeyError, Exception):
|
||||||
return f"Error: unknown timezone '{tz}'"
|
return f"Error: unknown timezone '{tz}'"
|
||||||
|
|
||||||
# Build schedule
|
# Build schedule
|
||||||
delete_after = False
|
delete_after = False
|
||||||
if every_seconds:
|
if every_seconds:
|
||||||
@@ -114,13 +109,14 @@ class CronTool(Tool):
|
|||||||
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
|
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
|
||||||
elif at:
|
elif at:
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
dt = datetime.fromisoformat(at)
|
dt = datetime.fromisoformat(at)
|
||||||
at_ms = int(dt.timestamp() * 1000)
|
at_ms = int(dt.timestamp() * 1000)
|
||||||
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
||||||
delete_after = True
|
delete_after = True
|
||||||
else:
|
else:
|
||||||
return "Error: either every_seconds, cron_expr, or at is required"
|
return "Error: either every_seconds, cron_expr, or at is required"
|
||||||
|
|
||||||
job = self._cron.add_job(
|
job = self._cron.add_job(
|
||||||
name=message[:30],
|
name=message[:30],
|
||||||
schedule=schedule,
|
schedule=schedule,
|
||||||
@@ -131,14 +127,14 @@ class CronTool(Tool):
|
|||||||
delete_after_run=delete_after,
|
delete_after_run=delete_after,
|
||||||
)
|
)
|
||||||
return f"Created job '{job.name}' (id: {job.id})"
|
return f"Created job '{job.name}' (id: {job.id})"
|
||||||
|
|
||||||
def _list_jobs(self) -> str:
|
def _list_jobs(self) -> str:
|
||||||
jobs = self._cron.list_jobs()
|
jobs = self._cron.list_jobs()
|
||||||
if not jobs:
|
if not jobs:
|
||||||
return "No scheduled jobs."
|
return "No scheduled jobs."
|
||||||
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
|
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
|
||||||
return "Scheduled jobs:\n" + "\n".join(lines)
|
return "Scheduled jobs:\n" + "\n".join(lines)
|
||||||
|
|
||||||
def _remove_job(self, job_id: str | None) -> str:
|
def _remove_job(self, job_id: str | None) -> str:
|
||||||
if not job_id:
|
if not job_id:
|
||||||
return "Error: job_id is required for remove"
|
return "Error: job_id is required for remove"
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ from typing import Any
|
|||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
|
def _resolve_path(
|
||||||
|
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
|
||||||
|
) -> Path:
|
||||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||||
p = Path(path).expanduser()
|
p = Path(path).expanduser()
|
||||||
if not p.is_absolute() and workspace:
|
if not p.is_absolute() and workspace:
|
||||||
@@ -31,24 +33,19 @@ class ReadFileTool(Tool):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "read_file"
|
return "read_file"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Read the contents of a file at the given path."
|
return "Read the contents of a file at the given path."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"path": {"type": "string", "description": "The file path to read"}},
|
||||||
"path": {
|
"required": ["path"],
|
||||||
"type": "string",
|
|
||||||
"description": "The file path to read"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["path"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
@@ -75,28 +72,22 @@ class WriteFileTool(Tool):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "write_file"
|
return "write_file"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Write content to a file at the given path. Creates parent directories if needed."
|
return "Write content to a file at the given path. Creates parent directories if needed."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"path": {
|
"path": {"type": "string", "description": "The file path to write to"},
|
||||||
"type": "string",
|
"content": {"type": "string", "description": "The content to write"},
|
||||||
"description": "The file path to write to"
|
|
||||||
},
|
|
||||||
"content": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The content to write"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["path", "content"]
|
"required": ["path", "content"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
@@ -119,32 +110,23 @@ class EditFileTool(Tool):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "edit_file"
|
return "edit_file"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"path": {
|
"path": {"type": "string", "description": "The file path to edit"},
|
||||||
"type": "string",
|
"old_text": {"type": "string", "description": "The exact text to find and replace"},
|
||||||
"description": "The file path to edit"
|
"new_text": {"type": "string", "description": "The text to replace with"},
|
||||||
},
|
|
||||||
"old_text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The exact text to find and replace"
|
|
||||||
},
|
|
||||||
"new_text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The text to replace with"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["path", "old_text", "new_text"]
|
"required": ["path", "old_text", "new_text"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
@@ -184,13 +166,19 @@ class EditFileTool(Tool):
|
|||||||
best_ratio, best_start = ratio, i
|
best_ratio, best_start = ratio, i
|
||||||
|
|
||||||
if best_ratio > 0.5:
|
if best_ratio > 0.5:
|
||||||
diff = "\n".join(difflib.unified_diff(
|
diff = "\n".join(
|
||||||
old_lines, lines[best_start : best_start + window],
|
difflib.unified_diff(
|
||||||
fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})",
|
old_lines,
|
||||||
lineterm="",
|
lines[best_start : best_start + window],
|
||||||
))
|
fromfile="old_text (provided)",
|
||||||
|
tofile=f"{path} (actual, line {best_start + 1})",
|
||||||
|
lineterm="",
|
||||||
|
)
|
||||||
|
)
|
||||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
return (
|
||||||
|
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ListDirTool(Tool):
|
class ListDirTool(Tool):
|
||||||
@@ -203,24 +191,19 @@ class ListDirTool(Tool):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "list_dir"
|
return "list_dir"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "List the contents of a directory."
|
return "List the contents of a directory."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"path": {"type": "string", "description": "The directory path to list"}},
|
||||||
"path": {
|
"required": ["path"],
|
||||||
"type": "string",
|
|
||||||
"description": "The directory path to list"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["path"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
|
|||||||
@@ -8,33 +8,33 @@ from nanobot.agent.tools.base import Tool
|
|||||||
class ToolRegistry:
|
class ToolRegistry:
|
||||||
"""
|
"""
|
||||||
Registry for agent tools.
|
Registry for agent tools.
|
||||||
|
|
||||||
Allows dynamic registration and execution of tools.
|
Allows dynamic registration and execution of tools.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._tools: dict[str, Tool] = {}
|
self._tools: dict[str, Tool] = {}
|
||||||
|
|
||||||
def register(self, tool: Tool) -> None:
|
def register(self, tool: Tool) -> None:
|
||||||
"""Register a tool."""
|
"""Register a tool."""
|
||||||
self._tools[tool.name] = tool
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
def unregister(self, name: str) -> None:
|
def unregister(self, name: str) -> None:
|
||||||
"""Unregister a tool by name."""
|
"""Unregister a tool by name."""
|
||||||
self._tools.pop(name, None)
|
self._tools.pop(name, None)
|
||||||
|
|
||||||
def get(self, name: str) -> Tool | None:
|
def get(self, name: str) -> Tool | None:
|
||||||
"""Get a tool by name."""
|
"""Get a tool by name."""
|
||||||
return self._tools.get(name)
|
return self._tools.get(name)
|
||||||
|
|
||||||
def has(self, name: str) -> bool:
|
def has(self, name: str) -> bool:
|
||||||
"""Check if a tool is registered."""
|
"""Check if a tool is registered."""
|
||||||
return name in self._tools
|
return name in self._tools
|
||||||
|
|
||||||
def get_definitions(self) -> list[dict[str, Any]]:
|
def get_definitions(self) -> list[dict[str, Any]]:
|
||||||
"""Get all tool definitions in OpenAI format."""
|
"""Get all tool definitions in OpenAI format."""
|
||||||
return [tool.to_schema() for tool in self._tools.values()]
|
return [tool.to_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||||
"""Execute a tool by name with given parameters."""
|
"""Execute a tool by name with given parameters."""
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
@@ -53,14 +53,14 @@ class ToolRegistry:
|
|||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing {name}: {str(e)}" + _HINT
|
return f"Error executing {name}: {str(e)}" + _HINT
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_names(self) -> list[str]:
|
def tool_names(self) -> list[str]:
|
||||||
"""Get list of registered tool names."""
|
"""Get list of registered tool names."""
|
||||||
return list(self._tools.keys())
|
return list(self._tools.keys())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._tools)
|
return len(self._tools)
|
||||||
|
|
||||||
def __contains__(self, name: str) -> bool:
|
def __contains__(self, name: str) -> bool:
|
||||||
return name in self._tools
|
return name in self._tools
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from nanobot.agent.tools.base import Tool
|
|||||||
|
|
||||||
class ExecTool(Tool):
|
class ExecTool(Tool):
|
||||||
"""Tool to execute shell commands."""
|
"""Tool to execute shell commands."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
timeout: int = 60,
|
timeout: int = 60,
|
||||||
@@ -37,15 +37,15 @@ class ExecTool(Tool):
|
|||||||
self.allow_patterns = allow_patterns or []
|
self.allow_patterns = allow_patterns or []
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self.path_append = path_append
|
self.path_append = path_append
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "exec"
|
return "exec"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Execute a shell command and return its output. Use with caution."
|
return "Execute a shell command and return its output. Use with caution."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Spawn tool for creating background subagents."""
|
"""Spawn tool for creating background subagents."""
|
||||||
|
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
|
|
||||||
@@ -10,23 +10,23 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class SpawnTool(Tool):
|
class SpawnTool(Tool):
|
||||||
"""Tool to spawn a subagent for background task execution."""
|
"""Tool to spawn a subagent for background task execution."""
|
||||||
|
|
||||||
def __init__(self, manager: "SubagentManager"):
|
def __init__(self, manager: "SubagentManager"):
|
||||||
self._manager = manager
|
self._manager = manager
|
||||||
self._origin_channel = "cli"
|
self._origin_channel = "cli"
|
||||||
self._origin_chat_id = "direct"
|
self._origin_chat_id = "direct"
|
||||||
self._session_key = "cli:direct"
|
self._session_key = "cli:direct"
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
def set_context(self, channel: str, chat_id: str) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
self._origin_channel = channel
|
self._origin_channel = channel
|
||||||
self._origin_chat_id = chat_id
|
self._origin_chat_id = chat_id
|
||||||
self._session_key = f"{channel}:{chat_id}"
|
self._session_key = f"{channel}:{chat_id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "spawn"
|
return "spawn"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
@@ -34,7 +34,7 @@ class SpawnTool(Tool):
|
|||||||
"Use this for complex or time-consuming tasks that can run independently. "
|
"Use this for complex or time-consuming tasks that can run independently. "
|
||||||
"The subagent will complete the task and report back when done."
|
"The subagent will complete the task and report back when done."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@@ -51,7 +51,7 @@ class SpawnTool(Tool):
|
|||||||
},
|
},
|
||||||
"required": ["task"],
|
"required": ["task"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
||||||
"""Spawn a subagent to execute the given task."""
|
"""Spawn a subagent to execute the given task."""
|
||||||
return await self._manager.spawn(
|
return await self._manager.spawn(
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def _validate_url(url: str) -> tuple[bool, str]:
|
|||||||
|
|
||||||
class WebSearchTool(Tool):
|
class WebSearchTool(Tool):
|
||||||
"""Search the web using Brave Search API."""
|
"""Search the web using Brave Search API."""
|
||||||
|
|
||||||
name = "web_search"
|
name = "web_search"
|
||||||
description = "Search the web. Returns titles, URLs, and snippets."
|
description = "Search the web. Returns titles, URLs, and snippets."
|
||||||
parameters = {
|
parameters = {
|
||||||
@@ -56,7 +56,7 @@ class WebSearchTool(Tool):
|
|||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
||||||
self._init_api_key = api_key
|
self._init_api_key = api_key
|
||||||
self.max_results = max_results
|
self.max_results = max_results
|
||||||
@@ -73,7 +73,7 @@ class WebSearchTool(Tool):
|
|||||||
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
|
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
|
||||||
"(or export BRAVE_API_KEY), then restart the gateway."
|
"(or export BRAVE_API_KEY), then restart the gateway."
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
n = min(max(count or self.max_results, 1), 10)
|
n = min(max(count or self.max_results, 1), 10)
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
@@ -84,11 +84,11 @@ class WebSearchTool(Tool):
|
|||||||
timeout=10.0
|
timeout=10.0
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
results = r.json().get("web", {}).get("results", [])
|
results = r.json().get("web", {}).get("results", [])
|
||||||
if not results:
|
if not results:
|
||||||
return f"No results for: {query}"
|
return f"No results for: {query}"
|
||||||
|
|
||||||
lines = [f"Results for: {query}\n"]
|
lines = [f"Results for: {query}\n"]
|
||||||
for i, item in enumerate(results[:n], 1):
|
for i, item in enumerate(results[:n], 1):
|
||||||
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
|
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
|
||||||
@@ -101,7 +101,7 @@ class WebSearchTool(Tool):
|
|||||||
|
|
||||||
class WebFetchTool(Tool):
|
class WebFetchTool(Tool):
|
||||||
"""Fetch and extract content from a URL using Readability."""
|
"""Fetch and extract content from a URL using Readability."""
|
||||||
|
|
||||||
name = "web_fetch"
|
name = "web_fetch"
|
||||||
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
||||||
parameters = {
|
parameters = {
|
||||||
@@ -113,10 +113,10 @@ class WebFetchTool(Tool):
|
|||||||
},
|
},
|
||||||
"required": ["url"]
|
"required": ["url"]
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, max_chars: int = 50000):
|
def __init__(self, max_chars: int = 50000):
|
||||||
self.max_chars = max_chars
|
self.max_chars = max_chars
|
||||||
|
|
||||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
|
||||||
from readability import Document
|
from readability import Document
|
||||||
|
|
||||||
@@ -135,9 +135,9 @@ class WebFetchTool(Tool):
|
|||||||
) as client:
|
) as client:
|
||||||
r = await client.get(url, headers={"User-Agent": USER_AGENT})
|
r = await client.get(url, headers={"User-Agent": USER_AGENT})
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
ctype = r.headers.get("content-type", "")
|
ctype = r.headers.get("content-type", "")
|
||||||
|
|
||||||
# JSON
|
# JSON
|
||||||
if "application/json" in ctype:
|
if "application/json" in ctype:
|
||||||
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
||||||
@@ -149,16 +149,16 @@ class WebFetchTool(Tool):
|
|||||||
extractor = "readability"
|
extractor = "readability"
|
||||||
else:
|
else:
|
||||||
text, extractor = r.text, "raw"
|
text, extractor = r.text, "raw"
|
||||||
|
|
||||||
truncated = len(text) > max_chars
|
truncated = len(text) > max_chars
|
||||||
if truncated:
|
if truncated:
|
||||||
text = text[:max_chars]
|
text = text[:max_chars]
|
||||||
|
|
||||||
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
||||||
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||||
|
|
||||||
def _to_markdown(self, html: str) -> str:
|
def _to_markdown(self, html: str) -> str:
|
||||||
"""Convert HTML to markdown."""
|
"""Convert HTML to markdown."""
|
||||||
# Convert links, headings, lists before stripping tags
|
# Convert links, headings, lists before stripping tags
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any
|
|||||||
@dataclass
|
@dataclass
|
||||||
class InboundMessage:
|
class InboundMessage:
|
||||||
"""Message received from a chat channel."""
|
"""Message received from a chat channel."""
|
||||||
|
|
||||||
channel: str # telegram, discord, slack, whatsapp
|
channel: str # telegram, discord, slack, whatsapp
|
||||||
sender_id: str # User identifier
|
sender_id: str # User identifier
|
||||||
chat_id: str # Chat/channel identifier
|
chat_id: str # Chat/channel identifier
|
||||||
@@ -17,7 +17,7 @@ class InboundMessage:
|
|||||||
media: list[str] = field(default_factory=list) # Media URLs
|
media: list[str] = field(default_factory=list) # Media URLs
|
||||||
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
||||||
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session_key(self) -> str:
|
def session_key(self) -> str:
|
||||||
"""Unique key for session identification."""
|
"""Unique key for session identification."""
|
||||||
@@ -27,7 +27,7 @@ class InboundMessage:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class OutboundMessage:
|
class OutboundMessage:
|
||||||
"""Message to send to a chat channel."""
|
"""Message to send to a chat channel."""
|
||||||
|
|
||||||
channel: str
|
channel: str
|
||||||
chat_id: str
|
chat_id: str
|
||||||
content: str
|
content: str
|
||||||
|
|||||||
@@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus
|
|||||||
class BaseChannel(ABC):
|
class BaseChannel(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for chat channel implementations.
|
Abstract base class for chat channel implementations.
|
||||||
|
|
||||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||||
to integrate with the nanobot message bus.
|
to integrate with the nanobot message bus.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "base"
|
name: str = "base"
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
"""
|
"""
|
||||||
Initialize the channel.
|
Initialize the channel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Channel-specific configuration.
|
config: Channel-specific configuration.
|
||||||
bus: The message bus for communication.
|
bus: The message bus for communication.
|
||||||
@@ -30,50 +30,50 @@ class BaseChannel(ABC):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""
|
"""
|
||||||
Start the channel and begin listening for messages.
|
Start the channel and begin listening for messages.
|
||||||
|
|
||||||
This should be a long-running async task that:
|
This should be a long-running async task that:
|
||||||
1. Connects to the chat platform
|
1. Connects to the chat platform
|
||||||
2. Listens for incoming messages
|
2. Listens for incoming messages
|
||||||
3. Forwards messages to the bus via _handle_message()
|
3. Forwards messages to the bus via _handle_message()
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the channel and clean up resources."""
|
"""Stop the channel and clean up resources."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""
|
"""
|
||||||
Send a message through this channel.
|
Send a message through this channel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The message to send.
|
msg: The message to send.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_allowed(self, sender_id: str) -> bool:
|
def is_allowed(self, sender_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a sender is allowed to use this bot.
|
Check if a sender is allowed to use this bot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sender_id: The sender's identifier.
|
sender_id: The sender's identifier.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if allowed, False otherwise.
|
True if allowed, False otherwise.
|
||||||
"""
|
"""
|
||||||
allow_list = getattr(self.config, "allow_from", [])
|
allow_list = getattr(self.config, "allow_from", [])
|
||||||
|
|
||||||
# If no allow list, allow everyone
|
# If no allow list, allow everyone
|
||||||
if not allow_list:
|
if not allow_list:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
sender_str = str(sender_id)
|
sender_str = str(sender_id)
|
||||||
if sender_str in allow_list:
|
if sender_str in allow_list:
|
||||||
return True
|
return True
|
||||||
@@ -82,7 +82,7 @@ class BaseChannel(ABC):
|
|||||||
if part and part in allow_list:
|
if part and part in allow_list:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _handle_message(
|
async def _handle_message(
|
||||||
self,
|
self,
|
||||||
sender_id: str,
|
sender_id: str,
|
||||||
@@ -94,9 +94,9 @@ class BaseChannel(ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Handle an incoming message from the chat platform.
|
Handle an incoming message from the chat platform.
|
||||||
|
|
||||||
This method checks permissions and forwards to the bus.
|
This method checks permissions and forwards to the bus.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sender_id: The sender's identifier.
|
sender_id: The sender's identifier.
|
||||||
chat_id: The chat/channel identifier.
|
chat_id: The chat/channel identifier.
|
||||||
@@ -112,7 +112,7 @@ class BaseChannel(ABC):
|
|||||||
sender_id, self.name,
|
sender_id, self.name,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel=self.name,
|
channel=self.name,
|
||||||
sender_id=str(sender_id),
|
sender_id=str(sender_id),
|
||||||
@@ -122,9 +122,9 @@ class BaseChannel(ABC):
|
|||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
session_key_override=session_key,
|
session_key_override=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""Check if the channel is running."""
|
"""Check if the channel is running."""
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import json
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@@ -15,11 +15,11 @@ from nanobot.config.schema import DingTalkConfig
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from dingtalk_stream import (
|
from dingtalk_stream import (
|
||||||
DingTalkStreamClient,
|
AckMessage,
|
||||||
Credential,
|
|
||||||
CallbackHandler,
|
CallbackHandler,
|
||||||
CallbackMessage,
|
CallbackMessage,
|
||||||
AckMessage,
|
Credential,
|
||||||
|
DingTalkStreamClient,
|
||||||
)
|
)
|
||||||
from dingtalk_stream.chatbot import ChatbotMessage
|
from dingtalk_stream.chatbot import ChatbotMessage
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import DiscordConfig
|
from nanobot.config.schema import DiscordConfig
|
||||||
|
|
||||||
|
|
||||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||||
|
|||||||
@@ -23,12 +23,11 @@ try:
|
|||||||
CreateFileRequestBody,
|
CreateFileRequestBody,
|
||||||
CreateImageRequest,
|
CreateImageRequest,
|
||||||
CreateImageRequestBody,
|
CreateImageRequestBody,
|
||||||
CreateMessageRequest,
|
|
||||||
CreateMessageRequestBody,
|
|
||||||
CreateMessageReactionRequest,
|
CreateMessageReactionRequest,
|
||||||
CreateMessageReactionRequestBody,
|
CreateMessageReactionRequestBody,
|
||||||
|
CreateMessageRequest,
|
||||||
|
CreateMessageRequestBody,
|
||||||
Emoji,
|
Emoji,
|
||||||
GetFileRequest,
|
|
||||||
GetMessageResourceRequest,
|
GetMessageResourceRequest,
|
||||||
P2ImMessageReceiveV1,
|
P2ImMessageReceiveV1,
|
||||||
)
|
)
|
||||||
@@ -70,7 +69,7 @@ def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
|||||||
def _extract_interactive_content(content: dict) -> list[str]:
|
def _extract_interactive_content(content: dict) -> list[str]:
|
||||||
"""Recursively extract text and links from interactive card content."""
|
"""Recursively extract text and links from interactive card content."""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
try:
|
try:
|
||||||
content = json.loads(content)
|
content = json.loads(content)
|
||||||
@@ -104,19 +103,19 @@ def _extract_interactive_content(content: dict) -> list[str]:
|
|||||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||||
if header_text:
|
if header_text:
|
||||||
parts.append(f"title: {header_text}")
|
parts.append(f"title: {header_text}")
|
||||||
|
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def _extract_element_content(element: dict) -> list[str]:
|
def _extract_element_content(element: dict) -> list[str]:
|
||||||
"""Extract content from a single card element."""
|
"""Extract content from a single card element."""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
if not isinstance(element, dict):
|
if not isinstance(element, dict):
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
tag = element.get("tag", "")
|
tag = element.get("tag", "")
|
||||||
|
|
||||||
if tag in ("markdown", "lark_md"):
|
if tag in ("markdown", "lark_md"):
|
||||||
content = element.get("content", "")
|
content = element.get("content", "")
|
||||||
if content:
|
if content:
|
||||||
@@ -177,17 +176,17 @@ def _extract_element_content(element: dict) -> list[str]:
|
|||||||
else:
|
else:
|
||||||
for ne in element.get("elements", []):
|
for ne in element.get("elements", []):
|
||||||
parts.extend(_extract_element_content(ne))
|
parts.extend(_extract_element_content(ne))
|
||||||
|
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||||
"""Extract text and image keys from Feishu post (rich text) message content.
|
"""Extract text and image keys from Feishu post (rich text) message content.
|
||||||
|
|
||||||
Supports two formats:
|
Supports two formats:
|
||||||
1. Direct format: {"title": "...", "content": [...]}
|
1. Direct format: {"title": "...", "content": [...]}
|
||||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(text, image_keys) - extracted text and list of image keys
|
(text, image_keys) - extracted text and list of image keys
|
||||||
"""
|
"""
|
||||||
@@ -220,26 +219,26 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
|||||||
image_keys.append(img_key)
|
image_keys.append(img_key)
|
||||||
text = " ".join(text_parts).strip() if text_parts else None
|
text = " ".join(text_parts).strip() if text_parts else None
|
||||||
return text, image_keys
|
return text, image_keys
|
||||||
|
|
||||||
# Try direct format first
|
# Try direct format first
|
||||||
if "content" in content_json:
|
if "content" in content_json:
|
||||||
text, images = extract_from_lang(content_json)
|
text, images = extract_from_lang(content_json)
|
||||||
if text or images:
|
if text or images:
|
||||||
return text or "", images
|
return text or "", images
|
||||||
|
|
||||||
# Try localized format
|
# Try localized format
|
||||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||||
lang_content = content_json.get(lang_key)
|
lang_content = content_json.get(lang_key)
|
||||||
text, images = extract_from_lang(lang_content)
|
text, images = extract_from_lang(lang_content)
|
||||||
if text or images:
|
if text or images:
|
||||||
return text or "", images
|
return text or "", images
|
||||||
|
|
||||||
return "", []
|
return "", []
|
||||||
|
|
||||||
|
|
||||||
def _extract_post_text(content_json: dict) -> str:
|
def _extract_post_text(content_json: dict) -> str:
|
||||||
"""Extract plain text from Feishu post (rich text) message content.
|
"""Extract plain text from Feishu post (rich text) message content.
|
||||||
|
|
||||||
Legacy wrapper for _extract_post_content, returns only text.
|
Legacy wrapper for _extract_post_content, returns only text.
|
||||||
"""
|
"""
|
||||||
text, _ = _extract_post_content(content_json)
|
text, _ = _extract_post_content(content_json)
|
||||||
@@ -249,17 +248,17 @@ def _extract_post_text(content_json: dict) -> str:
|
|||||||
class FeishuChannel(BaseChannel):
|
class FeishuChannel(BaseChannel):
|
||||||
"""
|
"""
|
||||||
Feishu/Lark channel using WebSocket long connection.
|
Feishu/Lark channel using WebSocket long connection.
|
||||||
|
|
||||||
Uses WebSocket to receive events - no public IP or webhook required.
|
Uses WebSocket to receive events - no public IP or webhook required.
|
||||||
|
|
||||||
Requires:
|
Requires:
|
||||||
- App ID and App Secret from Feishu Open Platform
|
- App ID and App Secret from Feishu Open Platform
|
||||||
- Bot capability enabled
|
- Bot capability enabled
|
||||||
- Event subscription enabled (im.message.receive_v1)
|
- Event subscription enabled (im.message.receive_v1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "feishu"
|
name = "feishu"
|
||||||
|
|
||||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.config: FeishuConfig = config
|
self.config: FeishuConfig = config
|
||||||
@@ -268,27 +267,27 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._ws_thread: threading.Thread | None = None
|
self._ws_thread: threading.Thread | None = None
|
||||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Feishu bot with WebSocket long connection."""
|
"""Start the Feishu bot with WebSocket long connection."""
|
||||||
if not FEISHU_AVAILABLE:
|
if not FEISHU_AVAILABLE:
|
||||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.app_id or not self.config.app_secret:
|
if not self.config.app_id or not self.config.app_secret:
|
||||||
logger.error("Feishu app_id and app_secret not configured")
|
logger.error("Feishu app_id and app_secret not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
# Create Lark client for sending messages
|
# Create Lark client for sending messages
|
||||||
self._client = lark.Client.builder() \
|
self._client = lark.Client.builder() \
|
||||||
.app_id(self.config.app_id) \
|
.app_id(self.config.app_id) \
|
||||||
.app_secret(self.config.app_secret) \
|
.app_secret(self.config.app_secret) \
|
||||||
.log_level(lark.LogLevel.INFO) \
|
.log_level(lark.LogLevel.INFO) \
|
||||||
.build()
|
.build()
|
||||||
|
|
||||||
# Create event handler (only register message receive, ignore other events)
|
# Create event handler (only register message receive, ignore other events)
|
||||||
event_handler = lark.EventDispatcherHandler.builder(
|
event_handler = lark.EventDispatcherHandler.builder(
|
||||||
self.config.encrypt_key or "",
|
self.config.encrypt_key or "",
|
||||||
@@ -296,7 +295,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
).register_p2_im_message_receive_v1(
|
).register_p2_im_message_receive_v1(
|
||||||
self._on_message_sync
|
self._on_message_sync
|
||||||
).build()
|
).build()
|
||||||
|
|
||||||
# Create WebSocket client for long connection
|
# Create WebSocket client for long connection
|
||||||
self._ws_client = lark.ws.Client(
|
self._ws_client = lark.ws.Client(
|
||||||
self.config.app_id,
|
self.config.app_id,
|
||||||
@@ -304,7 +303,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
event_handler=event_handler,
|
event_handler=event_handler,
|
||||||
log_level=lark.LogLevel.INFO
|
log_level=lark.LogLevel.INFO
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start WebSocket client in a separate thread with reconnect loop
|
# Start WebSocket client in a separate thread with reconnect loop
|
||||||
def run_ws():
|
def run_ws():
|
||||||
while self._running:
|
while self._running:
|
||||||
@@ -313,18 +312,19 @@ class FeishuChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Feishu WebSocket error: {}", e)
|
logger.warning("Feishu WebSocket error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
import time; time.sleep(5)
|
import time
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||||
self._ws_thread.start()
|
self._ws_thread.start()
|
||||||
|
|
||||||
logger.info("Feishu bot started with WebSocket long connection")
|
logger.info("Feishu bot started with WebSocket long connection")
|
||||||
logger.info("No public IP required - using WebSocket to receive events")
|
logger.info("No public IP required - using WebSocket to receive events")
|
||||||
|
|
||||||
# Keep running until stopped
|
# Keep running until stopped
|
||||||
while self._running:
|
while self._running:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the Feishu bot."""
|
"""Stop the Feishu bot."""
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -334,7 +334,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error stopping WebSocket client: {}", e)
|
logger.warning("Error stopping WebSocket client: {}", e)
|
||||||
logger.info("Feishu bot stopped")
|
logger.info("Feishu bot stopped")
|
||||||
|
|
||||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||||
try:
|
try:
|
||||||
@@ -345,9 +345,9 @@ class FeishuChannel(BaseChannel):
|
|||||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||||
.build()
|
.build()
|
||||||
).build()
|
).build()
|
||||||
|
|
||||||
response = self._client.im.v1.message_reaction.create(request)
|
response = self._client.im.v1.message_reaction.create(request)
|
||||||
|
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||||
else:
|
else:
|
||||||
@@ -358,15 +358,15 @@ class FeishuChannel(BaseChannel):
|
|||||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||||
"""
|
"""
|
||||||
Add a reaction emoji to a message (non-blocking).
|
Add a reaction emoji to a message (non-blocking).
|
||||||
|
|
||||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||||
"""
|
"""
|
||||||
if not self._client or not Emoji:
|
if not self._client or not Emoji:
|
||||||
return
|
return
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||||
|
|
||||||
# Regex to match markdown tables (header + separator + data rows)
|
# Regex to match markdown tables (header + separator + data rows)
|
||||||
_TABLE_RE = re.compile(
|
_TABLE_RE = re.compile(
|
||||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||||
@@ -380,12 +380,13 @@ class FeishuChannel(BaseChannel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_md_table(table_text: str) -> dict | None:
|
def _parse_md_table(table_text: str) -> dict | None:
|
||||||
"""Parse a markdown table into a Feishu table element."""
|
"""Parse a markdown table into a Feishu table element."""
|
||||||
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()]
|
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
|
||||||
if len(lines) < 3:
|
if len(lines) < 3:
|
||||||
return None
|
return None
|
||||||
split = lambda l: [c.strip() for c in l.strip("|").split("|")]
|
def split(_line: str) -> list[str]:
|
||||||
|
return [c.strip() for c in _line.strip("|").split("|")]
|
||||||
headers = split(lines[0])
|
headers = split(lines[0])
|
||||||
rows = [split(l) for l in lines[2:]]
|
rows = [split(_line) for _line in lines[2:]]
|
||||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||||
for i, h in enumerate(headers)]
|
for i, h in enumerate(headers)]
|
||||||
return {
|
return {
|
||||||
@@ -657,7 +658,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending Feishu message: {}", e)
|
logger.error("Error sending Feishu message: {}", e)
|
||||||
|
|
||||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||||
"""
|
"""
|
||||||
Sync handler for incoming messages (called from WebSocket thread).
|
Sync handler for incoming messages (called from WebSocket thread).
|
||||||
@@ -665,7 +666,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
if self._loop and self._loop.is_running():
|
if self._loop and self._loop.is_running():
|
||||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||||
|
|
||||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
||||||
"""Handle incoming message from Feishu."""
|
"""Handle incoming message from Feishu."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -16,24 +16,24 @@ from nanobot.config.schema import Config
|
|||||||
class ChannelManager:
|
class ChannelManager:
|
||||||
"""
|
"""
|
||||||
Manages chat channels and coordinates message routing.
|
Manages chat channels and coordinates message routing.
|
||||||
|
|
||||||
Responsibilities:
|
Responsibilities:
|
||||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||||
- Start/stop channels
|
- Start/stop channels
|
||||||
- Route outbound messages
|
- Route outbound messages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Config, bus: MessageBus):
|
def __init__(self, config: Config, bus: MessageBus):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.channels: dict[str, BaseChannel] = {}
|
self.channels: dict[str, BaseChannel] = {}
|
||||||
self._dispatch_task: asyncio.Task | None = None
|
self._dispatch_task: asyncio.Task | None = None
|
||||||
|
|
||||||
self._init_channels()
|
self._init_channels()
|
||||||
|
|
||||||
def _init_channels(self) -> None:
|
def _init_channels(self) -> None:
|
||||||
"""Initialize channels based on config."""
|
"""Initialize channels based on config."""
|
||||||
|
|
||||||
# Telegram channel
|
# Telegram channel
|
||||||
if self.config.channels.telegram.enabled:
|
if self.config.channels.telegram.enabled:
|
||||||
try:
|
try:
|
||||||
@@ -46,7 +46,7 @@ class ChannelManager:
|
|||||||
logger.info("Telegram channel enabled")
|
logger.info("Telegram channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Telegram channel not available: {}", e)
|
logger.warning("Telegram channel not available: {}", e)
|
||||||
|
|
||||||
# WhatsApp channel
|
# WhatsApp channel
|
||||||
if self.config.channels.whatsapp.enabled:
|
if self.config.channels.whatsapp.enabled:
|
||||||
try:
|
try:
|
||||||
@@ -68,7 +68,7 @@ class ChannelManager:
|
|||||||
logger.info("Discord channel enabled")
|
logger.info("Discord channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Discord channel not available: {}", e)
|
logger.warning("Discord channel not available: {}", e)
|
||||||
|
|
||||||
# Feishu channel
|
# Feishu channel
|
||||||
if self.config.channels.feishu.enabled:
|
if self.config.channels.feishu.enabled:
|
||||||
try:
|
try:
|
||||||
@@ -136,7 +136,7 @@ class ChannelManager:
|
|||||||
logger.info("QQ channel enabled")
|
logger.info("QQ channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("QQ channel not available: {}", e)
|
logger.warning("QQ channel not available: {}", e)
|
||||||
|
|
||||||
# Matrix channel
|
# Matrix channel
|
||||||
if self.config.channels.matrix.enabled:
|
if self.config.channels.matrix.enabled:
|
||||||
try:
|
try:
|
||||||
@@ -148,7 +148,7 @@ class ChannelManager:
|
|||||||
logger.info("Matrix channel enabled")
|
logger.info("Matrix channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Matrix channel not available: {}", e)
|
logger.warning("Matrix channel not available: {}", e)
|
||||||
|
|
||||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||||
"""Start a channel and log any exceptions."""
|
"""Start a channel and log any exceptions."""
|
||||||
try:
|
try:
|
||||||
@@ -161,23 +161,23 @@ class ChannelManager:
|
|||||||
if not self.channels:
|
if not self.channels:
|
||||||
logger.warning("No channels enabled")
|
logger.warning("No channels enabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Start outbound dispatcher
|
# Start outbound dispatcher
|
||||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||||
|
|
||||||
# Start channels
|
# Start channels
|
||||||
tasks = []
|
tasks = []
|
||||||
for name, channel in self.channels.items():
|
for name, channel in self.channels.items():
|
||||||
logger.info("Starting {} channel...", name)
|
logger.info("Starting {} channel...", name)
|
||||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||||
|
|
||||||
# Wait for all to complete (they should run forever)
|
# Wait for all to complete (they should run forever)
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def stop_all(self) -> None:
|
async def stop_all(self) -> None:
|
||||||
"""Stop all channels and the dispatcher."""
|
"""Stop all channels and the dispatcher."""
|
||||||
logger.info("Stopping all channels...")
|
logger.info("Stopping all channels...")
|
||||||
|
|
||||||
# Stop dispatcher
|
# Stop dispatcher
|
||||||
if self._dispatch_task:
|
if self._dispatch_task:
|
||||||
self._dispatch_task.cancel()
|
self._dispatch_task.cancel()
|
||||||
@@ -185,7 +185,7 @@ class ChannelManager:
|
|||||||
await self._dispatch_task
|
await self._dispatch_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Stop all channels
|
# Stop all channels
|
||||||
for name, channel in self.channels.items():
|
for name, channel in self.channels.items():
|
||||||
try:
|
try:
|
||||||
@@ -193,24 +193,24 @@ class ChannelManager:
|
|||||||
logger.info("Stopped {} channel", name)
|
logger.info("Stopped {} channel", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error stopping {}: {}", name, e)
|
logger.error("Error stopping {}: {}", name, e)
|
||||||
|
|
||||||
async def _dispatch_outbound(self) -> None:
|
async def _dispatch_outbound(self) -> None:
|
||||||
"""Dispatch outbound messages to the appropriate channel."""
|
"""Dispatch outbound messages to the appropriate channel."""
|
||||||
logger.info("Outbound dispatcher started")
|
logger.info("Outbound dispatcher started")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(
|
msg = await asyncio.wait_for(
|
||||||
self.bus.consume_outbound(),
|
self.bus.consume_outbound(),
|
||||||
timeout=1.0
|
timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg.metadata.get("_progress"):
|
if msg.metadata.get("_progress"):
|
||||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||||
continue
|
continue
|
||||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
@@ -219,16 +219,16 @@ class ChannelManager:
|
|||||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown channel: {}", msg.channel)
|
logger.warning("Unknown channel: {}", msg.channel)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
|
|
||||||
def get_channel(self, name: str) -> BaseChannel | None:
|
def get_channel(self, name: str) -> BaseChannel | None:
|
||||||
"""Get a channel by name."""
|
"""Get a channel by name."""
|
||||||
return self.channels.get(name)
|
return self.channels.get(name)
|
||||||
|
|
||||||
def get_status(self) -> dict[str, Any]:
|
def get_status(self) -> dict[str, Any]:
|
||||||
"""Get status of all channels."""
|
"""Get status of all channels."""
|
||||||
return {
|
return {
|
||||||
@@ -238,7 +238,7 @@ class ChannelManager:
|
|||||||
}
|
}
|
||||||
for name, channel in self.channels.items()
|
for name, channel in self.channels.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled_channels(self) -> list[str]:
|
def enabled_channels(self) -> list[str]:
|
||||||
"""Get list of enabled channel names."""
|
"""Get list of enabled channel names."""
|
||||||
|
|||||||
@@ -12,10 +12,22 @@ try:
|
|||||||
import nh3
|
import nh3
|
||||||
from mistune import create_markdown
|
from mistune import create_markdown
|
||||||
from nio import (
|
from nio import (
|
||||||
AsyncClient, AsyncClientConfig, ContentRepositoryConfigError,
|
AsyncClient,
|
||||||
DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse,
|
AsyncClientConfig,
|
||||||
RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText,
|
ContentRepositoryConfigError,
|
||||||
RoomSendError, RoomTypingError, SyncError, UploadError,
|
DownloadError,
|
||||||
|
InviteEvent,
|
||||||
|
JoinError,
|
||||||
|
MatrixRoom,
|
||||||
|
MemoryDownloadResponse,
|
||||||
|
RoomEncryptedMedia,
|
||||||
|
RoomMessage,
|
||||||
|
RoomMessageMedia,
|
||||||
|
RoomMessageText,
|
||||||
|
RoomSendError,
|
||||||
|
RoomTypingError,
|
||||||
|
SyncError,
|
||||||
|
UploadError,
|
||||||
)
|
)
|
||||||
from nio.crypto.attachments import decrypt_attachment
|
from nio.crypto.attachments import decrypt_attachment
|
||||||
from nio.exceptions import EncryptionError
|
from nio.exceptions import EncryptionError
|
||||||
|
|||||||
@@ -5,11 +5,10 @@ import re
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
|
||||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||||
|
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||||
from slack_sdk.web.async_client import AsyncWebClient
|
from slack_sdk.web.async_client import AsyncWebClient
|
||||||
|
|
||||||
from slackify_markdown import slackify_markdown
|
from slackify_markdown import slackify_markdown
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from telegram import BotCommand, Update, ReplyParameters
|
from telegram import BotCommand, ReplyParameters, Update
|
||||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
|
||||||
from telegram.request import HTTPXRequest
|
from telegram.request import HTTPXRequest
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@@ -21,60 +22,60 @@ def _markdown_to_telegram_html(text: str) -> str:
|
|||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||||
code_blocks: list[str] = []
|
code_blocks: list[str] = []
|
||||||
def save_code_block(m: re.Match) -> str:
|
def save_code_block(m: re.Match) -> str:
|
||||||
code_blocks.append(m.group(1))
|
code_blocks.append(m.group(1))
|
||||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||||
|
|
||||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||||
|
|
||||||
# 2. Extract and protect inline code
|
# 2. Extract and protect inline code
|
||||||
inline_codes: list[str] = []
|
inline_codes: list[str] = []
|
||||||
def save_inline_code(m: re.Match) -> str:
|
def save_inline_code(m: re.Match) -> str:
|
||||||
inline_codes.append(m.group(1))
|
inline_codes.append(m.group(1))
|
||||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||||
|
|
||||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||||
|
|
||||||
# 3. Headers # Title -> just the title text
|
# 3. Headers # Title -> just the title text
|
||||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||||
|
|
||||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||||
|
|
||||||
# 5. Escape HTML special characters
|
# 5. Escape HTML special characters
|
||||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
|
|
||||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||||
|
|
||||||
# 7. Bold **text** or __text__
|
# 7. Bold **text** or __text__
|
||||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||||
|
|
||||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||||
|
|
||||||
# 9. Strikethrough ~~text~~
|
# 9. Strikethrough ~~text~~
|
||||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||||
|
|
||||||
# 10. Bullet lists - item -> • item
|
# 10. Bullet lists - item -> • item
|
||||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||||
|
|
||||||
# 11. Restore inline code with HTML tags
|
# 11. Restore inline code with HTML tags
|
||||||
for i, code in enumerate(inline_codes):
|
for i, code in enumerate(inline_codes):
|
||||||
# Escape HTML in code content
|
# Escape HTML in code content
|
||||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||||
|
|
||||||
# 12. Restore code blocks with HTML tags
|
# 12. Restore code blocks with HTML tags
|
||||||
for i, code in enumerate(code_blocks):
|
for i, code in enumerate(code_blocks):
|
||||||
# Escape HTML in code content
|
# Escape HTML in code content
|
||||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@@ -101,12 +102,12 @@ def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
|||||||
class TelegramChannel(BaseChannel):
|
class TelegramChannel(BaseChannel):
|
||||||
"""
|
"""
|
||||||
Telegram channel using long polling.
|
Telegram channel using long polling.
|
||||||
|
|
||||||
Simple and reliable - no webhook/public IP needed.
|
Simple and reliable - no webhook/public IP needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "telegram"
|
name = "telegram"
|
||||||
|
|
||||||
# Commands registered with Telegram's command menu
|
# Commands registered with Telegram's command menu
|
||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
BotCommand("start", "Start the bot"),
|
BotCommand("start", "Start the bot"),
|
||||||
@@ -114,7 +115,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
BotCommand("stop", "Stop the current task"),
|
BotCommand("stop", "Stop the current task"),
|
||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TelegramConfig,
|
config: TelegramConfig,
|
||||||
@@ -129,15 +130,15 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
self._media_group_buffers: dict[str, dict] = {}
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
if not self.config.token:
|
if not self.config.token:
|
||||||
logger.error("Telegram bot token not configured")
|
logger.error("Telegram bot token not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||||
@@ -145,51 +146,51 @@ class TelegramChannel(BaseChannel):
|
|||||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||||
self._app = builder.build()
|
self._app = builder.build()
|
||||||
self._app.add_error_handler(self._on_error)
|
self._app.add_error_handler(self._on_error)
|
||||||
|
|
||||||
# Add command handlers
|
# Add command handlers
|
||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
self._app.add_handler(
|
self._app.add_handler(
|
||||||
MessageHandler(
|
MessageHandler(
|
||||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||||
& ~filters.COMMAND,
|
& ~filters.COMMAND,
|
||||||
self._on_message
|
self._on_message
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting Telegram bot (polling mode)...")
|
logger.info("Starting Telegram bot (polling mode)...")
|
||||||
|
|
||||||
# Initialize and start polling
|
# Initialize and start polling
|
||||||
await self._app.initialize()
|
await self._app.initialize()
|
||||||
await self._app.start()
|
await self._app.start()
|
||||||
|
|
||||||
# Get bot info and register command menu
|
# Get bot info and register command menu
|
||||||
bot_info = await self._app.bot.get_me()
|
bot_info = await self._app.bot.get_me()
|
||||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||||
logger.debug("Telegram bot commands registered")
|
logger.debug("Telegram bot commands registered")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to register bot commands: {}", e)
|
logger.warning("Failed to register bot commands: {}", e)
|
||||||
|
|
||||||
# Start polling (this runs until stopped)
|
# Start polling (this runs until stopped)
|
||||||
await self._app.updater.start_polling(
|
await self._app.updater.start_polling(
|
||||||
allowed_updates=["message"],
|
allowed_updates=["message"],
|
||||||
drop_pending_updates=True # Ignore old messages on startup
|
drop_pending_updates=True # Ignore old messages on startup
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep running until stopped
|
# Keep running until stopped
|
||||||
while self._running:
|
while self._running:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the Telegram bot."""
|
"""Stop the Telegram bot."""
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# Cancel all typing indicators
|
# Cancel all typing indicators
|
||||||
for chat_id in list(self._typing_tasks):
|
for chat_id in list(self._typing_tasks):
|
||||||
self._stop_typing(chat_id)
|
self._stop_typing(chat_id)
|
||||||
@@ -198,14 +199,14 @@ class TelegramChannel(BaseChannel):
|
|||||||
task.cancel()
|
task.cancel()
|
||||||
self._media_group_tasks.clear()
|
self._media_group_tasks.clear()
|
||||||
self._media_group_buffers.clear()
|
self._media_group_buffers.clear()
|
||||||
|
|
||||||
if self._app:
|
if self._app:
|
||||||
logger.info("Stopping Telegram bot...")
|
logger.info("Stopping Telegram bot...")
|
||||||
await self._app.updater.stop()
|
await self._app.updater.stop()
|
||||||
await self._app.stop()
|
await self._app.stop()
|
||||||
await self._app.shutdown()
|
await self._app.shutdown()
|
||||||
self._app = None
|
self._app = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_media_type(path: str) -> str:
|
def _get_media_type(path: str) -> str:
|
||||||
"""Guess media type from file extension."""
|
"""Guess media type from file extension."""
|
||||||
@@ -253,7 +254,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||||
with open(media_path, 'rb') as f:
|
with open(media_path, 'rb') as f:
|
||||||
await sender(
|
await sender(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
**{param: f},
|
**{param: f},
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params
|
||||||
)
|
)
|
||||||
@@ -272,8 +273,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
html = _markdown_to_telegram_html(chunk)
|
html = _markdown_to_telegram_html(chunk)
|
||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=html,
|
text=html,
|
||||||
parse_mode="HTML",
|
parse_mode="HTML",
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params
|
||||||
)
|
)
|
||||||
@@ -281,13 +282,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
try:
|
try:
|
||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=chunk,
|
text=chunk,
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params
|
||||||
)
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
logger.error("Error sending Telegram message: {}", e2)
|
logger.error("Error sending Telegram message: {}", e2)
|
||||||
|
|
||||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
@@ -326,34 +327,34 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id=str(update.message.chat_id),
|
chat_id=str(update.message.chat_id),
|
||||||
content=update.message.text,
|
content=update.message.text,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
|
|
||||||
message = update.message
|
message = update.message
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
chat_id = message.chat_id
|
chat_id = message.chat_id
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
|
|
||||||
# Store chat_id for replies
|
# Store chat_id for replies
|
||||||
self._chat_ids[sender_id] = chat_id
|
self._chat_ids[sender_id] = chat_id
|
||||||
|
|
||||||
# Build content from text and/or media
|
# Build content from text and/or media
|
||||||
content_parts = []
|
content_parts = []
|
||||||
media_paths = []
|
media_paths = []
|
||||||
|
|
||||||
# Text content
|
# Text content
|
||||||
if message.text:
|
if message.text:
|
||||||
content_parts.append(message.text)
|
content_parts.append(message.text)
|
||||||
if message.caption:
|
if message.caption:
|
||||||
content_parts.append(message.caption)
|
content_parts.append(message.caption)
|
||||||
|
|
||||||
# Handle media files
|
# Handle media files
|
||||||
media_file = None
|
media_file = None
|
||||||
media_type = None
|
media_type = None
|
||||||
|
|
||||||
if message.photo:
|
if message.photo:
|
||||||
media_file = message.photo[-1] # Largest photo
|
media_file = message.photo[-1] # Largest photo
|
||||||
media_type = "image"
|
media_type = "image"
|
||||||
@@ -366,23 +367,23 @@ class TelegramChannel(BaseChannel):
|
|||||||
elif message.document:
|
elif message.document:
|
||||||
media_file = message.document
|
media_file = message.document
|
||||||
media_type = "file"
|
media_type = "file"
|
||||||
|
|
||||||
# Download media if present
|
# Download media if present
|
||||||
if media_file and self._app:
|
if media_file and self._app:
|
||||||
try:
|
try:
|
||||||
file = await self._app.bot.get_file(media_file.file_id)
|
file = await self._app.bot.get_file(media_file.file_id)
|
||||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||||
|
|
||||||
# Save to workspace/media/
|
# Save to workspace/media/
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
media_dir = Path.home() / ".nanobot" / "media"
|
media_dir = Path.home() / ".nanobot" / "media"
|
||||||
media_dir.mkdir(parents=True, exist_ok=True)
|
media_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||||
await file.download_to_drive(str(file_path))
|
await file.download_to_drive(str(file_path))
|
||||||
|
|
||||||
media_paths.append(str(file_path))
|
media_paths.append(str(file_path))
|
||||||
|
|
||||||
# Handle voice transcription
|
# Handle voice transcription
|
||||||
if media_type == "voice" or media_type == "audio":
|
if media_type == "voice" or media_type == "audio":
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||||
@@ -395,16 +396,16 @@ class TelegramChannel(BaseChannel):
|
|||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
content_parts.append(f"[{media_type}: {file_path}]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
content_parts.append(f"[{media_type}: {file_path}]")
|
||||||
|
|
||||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to download media: {}", e)
|
logger.error("Failed to download media: {}", e)
|
||||||
content_parts.append(f"[{media_type}: download failed]")
|
content_parts.append(f"[{media_type}: download failed]")
|
||||||
|
|
||||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||||
|
|
||||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
|
||||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||||
@@ -428,10 +429,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
if key not in self._media_group_tasks:
|
if key not in self._media_group_tasks:
|
||||||
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Start typing indicator before processing
|
# Start typing indicator before processing
|
||||||
self._start_typing(str_chat_id)
|
self._start_typing(str_chat_id)
|
||||||
|
|
||||||
# Forward to the message bus
|
# Forward to the message bus
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
@@ -446,7 +447,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
"is_group": message.chat.type != "private"
|
"is_group": message.chat.type != "private"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _flush_media_group(self, key: str) -> None:
|
async def _flush_media_group(self, key: str) -> None:
|
||||||
"""Wait briefly, then forward buffered media-group as one turn."""
|
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||||
try:
|
try:
|
||||||
@@ -467,13 +468,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Cancel any existing typing task for this chat
|
# Cancel any existing typing task for this chat
|
||||||
self._stop_typing(chat_id)
|
self._stop_typing(chat_id)
|
||||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||||
|
|
||||||
def _stop_typing(self, chat_id: str) -> None:
|
def _stop_typing(self, chat_id: str) -> None:
|
||||||
"""Stop the typing indicator for a chat."""
|
"""Stop the typing indicator for a chat."""
|
||||||
task = self._typing_tasks.pop(chat_id, None)
|
task = self._typing_tasks.pop(chat_id, None)
|
||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
async def _typing_loop(self, chat_id: str) -> None:
|
async def _typing_loop(self, chat_id: str) -> None:
|
||||||
"""Repeatedly send 'typing' action until cancelled."""
|
"""Repeatedly send 'typing' action until cancelled."""
|
||||||
try:
|
try:
|
||||||
@@ -484,7 +485,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||||
|
|
||||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Log polling / handler errors instead of silently swallowing them."""
|
"""Log polling / handler errors instead of silently swallowing them."""
|
||||||
logger.error("Telegram error: {}", context.error)
|
logger.error("Telegram error: {}", context.error)
|
||||||
@@ -498,6 +499,6 @@ class TelegramChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
if mime_type in ext_map:
|
if mime_type in ext_map:
|
||||||
return ext_map[mime_type]
|
return ext_map[mime_type]
|
||||||
|
|
||||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||||
return type_map.get(media_type, "")
|
return type_map.get(media_type, "")
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@@ -29,17 +28,17 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
self._ws = None
|
self._ws = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
bridge_url = self.config.bridge_url
|
bridge_url = self.config.bridge_url
|
||||||
|
|
||||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(bridge_url) as ws:
|
async with websockets.connect(bridge_url) as ws:
|
||||||
@@ -49,40 +48,40 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||||
self._connected = True
|
self._connected = True
|
||||||
logger.info("Connected to WhatsApp bridge")
|
logger.info("Connected to WhatsApp bridge")
|
||||||
|
|
||||||
# Listen for messages
|
# Listen for messages
|
||||||
async for message in ws:
|
async for message in ws:
|
||||||
try:
|
try:
|
||||||
await self._handle_bridge_message(message)
|
await self._handle_bridge_message(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error handling bridge message: {}", e)
|
logger.error("Error handling bridge message: {}", e)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._ws = None
|
self._ws = None
|
||||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||||
|
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting in 5 seconds...")
|
logger.info("Reconnecting in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the WhatsApp channel."""
|
"""Stop the WhatsApp channel."""
|
||||||
self._running = False
|
self._running = False
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
if self._ws:
|
if self._ws:
|
||||||
await self._ws.close()
|
await self._ws.close()
|
||||||
self._ws = None
|
self._ws = None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through WhatsApp."""
|
"""Send a message through WhatsApp."""
|
||||||
if not self._ws or not self._connected:
|
if not self._ws or not self._connected:
|
||||||
logger.warning("WhatsApp bridge not connected")
|
logger.warning("WhatsApp bridge not connected")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = {
|
payload = {
|
||||||
"type": "send",
|
"type": "send",
|
||||||
@@ -92,7 +91,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending WhatsApp message: {}", e)
|
logger.error("Error sending WhatsApp message: {}", e)
|
||||||
|
|
||||||
async def _handle_bridge_message(self, raw: str) -> None:
|
async def _handle_bridge_message(self, raw: str) -> None:
|
||||||
"""Handle a message from the bridge."""
|
"""Handle a message from the bridge."""
|
||||||
try:
|
try:
|
||||||
@@ -100,9 +99,9 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||||
return
|
return
|
||||||
|
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
|
|
||||||
if msg_type == "message":
|
if msg_type == "message":
|
||||||
# Incoming message from WhatsApp
|
# Incoming message from WhatsApp
|
||||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||||
@@ -139,20 +138,20 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
"is_group": data.get("isGroup", False)
|
"is_group": data.get("isGroup", False)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
elif msg_type == "status":
|
elif msg_type == "status":
|
||||||
# Connection status update
|
# Connection status update
|
||||||
status = data.get("status")
|
status = data.get("status")
|
||||||
logger.info("WhatsApp status: {}", status)
|
logger.info("WhatsApp status: {}", status)
|
||||||
|
|
||||||
if status == "connected":
|
if status == "connected":
|
||||||
self._connected = True
|
self._connected = True
|
||||||
elif status == "disconnected":
|
elif status == "disconnected":
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
elif msg_type == "qr":
|
elif msg_type == "qr":
|
||||||
# QR code for authentication
|
# QR code for authentication
|
||||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||||
|
|
||||||
elif msg_type == "error":
|
elif msg_type == "error":
|
||||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||||
|
|||||||
@@ -2,23 +2,22 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
from pathlib import Path
|
|
||||||
import select
|
import select
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from prompt_toolkit import PromptSession
|
||||||
|
from prompt_toolkit.formatted_text import HTML
|
||||||
|
from prompt_toolkit.history import FileHistory
|
||||||
|
from prompt_toolkit.patch_stdout import patch_stdout
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from prompt_toolkit import PromptSession
|
from nanobot import __logo__, __version__
|
||||||
from prompt_toolkit.formatted_text import HTML
|
|
||||||
from prompt_toolkit.history import FileHistory
|
|
||||||
from prompt_toolkit.patch_stdout import patch_stdout
|
|
||||||
|
|
||||||
from nanobot import __version__, __logo__
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.utils.helpers import sync_workspace_templates
|
from nanobot.utils.helpers import sync_workspace_templates
|
||||||
|
|
||||||
@@ -160,9 +159,9 @@ def onboard():
|
|||||||
from nanobot.config.loader import get_config_path, load_config, save_config
|
from nanobot.config.loader import get_config_path, load_config, save_config
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.utils.helpers import get_workspace_path
|
from nanobot.utils.helpers import get_workspace_path
|
||||||
|
|
||||||
config_path = get_config_path()
|
config_path = get_config_path()
|
||||||
|
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
||||||
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
||||||
@@ -178,16 +177,16 @@ def onboard():
|
|||||||
else:
|
else:
|
||||||
save_config(Config())
|
save_config(Config())
|
||||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||||
|
|
||||||
# Create workspace
|
# Create workspace
|
||||||
workspace = get_workspace_path()
|
workspace = get_workspace_path()
|
||||||
|
|
||||||
if not workspace.exists():
|
if not workspace.exists():
|
||||||
workspace.mkdir(parents=True, exist_ok=True)
|
workspace.mkdir(parents=True, exist_ok=True)
|
||||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||||
|
|
||||||
sync_workspace_templates(workspace)
|
sync_workspace_templates(workspace)
|
||||||
|
|
||||||
console.print(f"\n{__logo__} nanobot is ready!")
|
console.print(f"\n{__logo__} nanobot is ready!")
|
||||||
console.print("\nNext steps:")
|
console.print("\nNext steps:")
|
||||||
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
|
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
|
||||||
@@ -201,9 +200,9 @@ def onboard():
|
|||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
|
from nanobot.providers.custom_provider import CustomProvider
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model)
|
||||||
@@ -248,31 +247,31 @@ def gateway(
|
|||||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
||||||
):
|
):
|
||||||
"""Start the nanobot gateway."""
|
"""Start the nanobot gateway."""
|
||||||
from nanobot.config.loader import load_config, get_data_dir
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.manager import ChannelManager
|
from nanobot.channels.manager import ChannelManager
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.config.loader import get_data_dir, load_config
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronJob
|
from nanobot.cron.types import CronJob
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
provider = _make_provider(config)
|
||||||
session_manager = SessionManager(config.workspace_path)
|
session_manager = SessionManager(config.workspace_path)
|
||||||
|
|
||||||
# Create cron service first (callback set after agent creation)
|
# Create cron service first (callback set after agent creation)
|
||||||
cron_store_path = get_data_dir() / "cron" / "jobs.json"
|
cron_store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
cron = CronService(cron_store_path)
|
cron = CronService(cron_store_path)
|
||||||
|
|
||||||
# Create agent with cron service
|
# Create agent with cron service
|
||||||
agent = AgentLoop(
|
agent = AgentLoop(
|
||||||
bus=bus,
|
bus=bus,
|
||||||
@@ -291,7 +290,7 @@ def gateway(
|
|||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
channels_config=config.channels,
|
channels_config=config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set cron callback (needs agent)
|
# Set cron callback (needs agent)
|
||||||
async def on_cron_job(job: CronJob) -> str | None:
|
async def on_cron_job(job: CronJob) -> str | None:
|
||||||
"""Execute a cron job through the agent."""
|
"""Execute a cron job through the agent."""
|
||||||
@@ -310,7 +309,7 @@ def gateway(
|
|||||||
))
|
))
|
||||||
return response
|
return response
|
||||||
cron.on_job = on_cron_job
|
cron.on_job = on_cron_job
|
||||||
|
|
||||||
# Create channel manager
|
# Create channel manager
|
||||||
channels = ChannelManager(config, bus)
|
channels = ChannelManager(config, bus)
|
||||||
|
|
||||||
@@ -364,18 +363,18 @@ def gateway(
|
|||||||
interval_s=hb_cfg.interval_s,
|
interval_s=hb_cfg.interval_s,
|
||||||
enabled=hb_cfg.enabled,
|
enabled=hb_cfg.enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
if channels.enabled_channels:
|
if channels.enabled_channels:
|
||||||
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
||||||
else:
|
else:
|
||||||
console.print("[yellow]Warning: No channels enabled[/yellow]")
|
console.print("[yellow]Warning: No channels enabled[/yellow]")
|
||||||
|
|
||||||
cron_status = cron.status()
|
cron_status = cron.status()
|
||||||
if cron_status["jobs"] > 0:
|
if cron_status["jobs"] > 0:
|
||||||
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
try:
|
try:
|
||||||
await cron.start()
|
await cron.start()
|
||||||
@@ -392,7 +391,7 @@ def gateway(
|
|||||||
cron.stop()
|
cron.stop()
|
||||||
agent.stop()
|
agent.stop()
|
||||||
await channels.stop_all()
|
await channels.stop_all()
|
||||||
|
|
||||||
asyncio.run(run())
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
@@ -411,15 +410,16 @@ def agent(
|
|||||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||||
):
|
):
|
||||||
"""Interact with the agent directly."""
|
"""Interact with the agent directly."""
|
||||||
from nanobot.config.loader import load_config, get_data_dir
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.cron.service import CronService
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.loader import get_data_dir, load_config
|
||||||
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
provider = _make_provider(config)
|
||||||
|
|
||||||
@@ -431,7 +431,7 @@ def agent(
|
|||||||
logger.enable("nanobot")
|
logger.enable("nanobot")
|
||||||
else:
|
else:
|
||||||
logger.disable("nanobot")
|
logger.disable("nanobot")
|
||||||
|
|
||||||
agent_loop = AgentLoop(
|
agent_loop = AgentLoop(
|
||||||
bus=bus,
|
bus=bus,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@@ -448,7 +448,7 @@ def agent(
|
|||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
channels_config=config.channels,
|
channels_config=config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
# Show spinner when logs are off (no output to miss); skip when logs are on
|
||||||
def _thinking_ctx():
|
def _thinking_ctx():
|
||||||
if logs:
|
if logs:
|
||||||
@@ -624,7 +624,7 @@ def channels_status():
|
|||||||
"✓" if mc.enabled else "✗",
|
"✓" if mc.enabled else "✗",
|
||||||
mc_base
|
mc_base
|
||||||
)
|
)
|
||||||
|
|
||||||
# Telegram
|
# Telegram
|
||||||
tg = config.channels.telegram
|
tg = config.channels.telegram
|
||||||
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
|
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
|
||||||
@@ -677,57 +677,57 @@ def _get_bridge_dir() -> Path:
|
|||||||
"""Get the bridge directory, setting it up if needed."""
|
"""Get the bridge directory, setting it up if needed."""
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
# User's bridge location
|
# User's bridge location
|
||||||
user_bridge = Path.home() / ".nanobot" / "bridge"
|
user_bridge = Path.home() / ".nanobot" / "bridge"
|
||||||
|
|
||||||
# Check if already built
|
# Check if already built
|
||||||
if (user_bridge / "dist" / "index.js").exists():
|
if (user_bridge / "dist" / "index.js").exists():
|
||||||
return user_bridge
|
return user_bridge
|
||||||
|
|
||||||
# Check for npm
|
# Check for npm
|
||||||
if not shutil.which("npm"):
|
if not shutil.which("npm"):
|
||||||
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
# Find source bridge: first check package data, then source dir
|
# Find source bridge: first check package data, then source dir
|
||||||
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
|
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
|
||||||
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
|
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
|
||||||
|
|
||||||
source = None
|
source = None
|
||||||
if (pkg_bridge / "package.json").exists():
|
if (pkg_bridge / "package.json").exists():
|
||||||
source = pkg_bridge
|
source = pkg_bridge
|
||||||
elif (src_bridge / "package.json").exists():
|
elif (src_bridge / "package.json").exists():
|
||||||
source = src_bridge
|
source = src_bridge
|
||||||
|
|
||||||
if not source:
|
if not source:
|
||||||
console.print("[red]Bridge source not found.[/red]")
|
console.print("[red]Bridge source not found.[/red]")
|
||||||
console.print("Try reinstalling: pip install --force-reinstall nanobot")
|
console.print("Try reinstalling: pip install --force-reinstall nanobot")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
console.print(f"{__logo__} Setting up bridge...")
|
console.print(f"{__logo__} Setting up bridge...")
|
||||||
|
|
||||||
# Copy to user directory
|
# Copy to user directory
|
||||||
user_bridge.parent.mkdir(parents=True, exist_ok=True)
|
user_bridge.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if user_bridge.exists():
|
if user_bridge.exists():
|
||||||
shutil.rmtree(user_bridge)
|
shutil.rmtree(user_bridge)
|
||||||
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
|
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
|
||||||
|
|
||||||
# Install and build
|
# Install and build
|
||||||
try:
|
try:
|
||||||
console.print(" Installing dependencies...")
|
console.print(" Installing dependencies...")
|
||||||
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
|
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
|
||||||
|
|
||||||
console.print(" Building...")
|
console.print(" Building...")
|
||||||
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
||||||
|
|
||||||
console.print("[green]✓[/green] Bridge ready\n")
|
console.print("[green]✓[/green] Bridge ready\n")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
console.print(f"[red]Build failed: {e}[/red]")
|
console.print(f"[red]Build failed: {e}[/red]")
|
||||||
if e.stderr:
|
if e.stderr:
|
||||||
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
|
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
return user_bridge
|
return user_bridge
|
||||||
|
|
||||||
|
|
||||||
@@ -735,18 +735,19 @@ def _get_bridge_dir() -> Path:
|
|||||||
def channels_login():
|
def channels_login():
|
||||||
"""Link device via QR code."""
|
"""Link device via QR code."""
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
from nanobot.config.loader import load_config
|
from nanobot.config.loader import load_config
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
bridge_dir = _get_bridge_dir()
|
bridge_dir = _get_bridge_dir()
|
||||||
|
|
||||||
console.print(f"{__logo__} Starting bridge...")
|
console.print(f"{__logo__} Starting bridge...")
|
||||||
console.print("Scan the QR code to connect.\n")
|
console.print("Scan the QR code to connect.\n")
|
||||||
|
|
||||||
env = {**os.environ}
|
env = {**os.environ}
|
||||||
if config.channels.whatsapp.bridge_token:
|
if config.channels.whatsapp.bridge_token:
|
||||||
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
@@ -770,23 +771,23 @@ def cron_list(
|
|||||||
"""List scheduled jobs."""
|
"""List scheduled jobs."""
|
||||||
from nanobot.config.loader import get_data_dir
|
from nanobot.config.loader import get_data_dir
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
jobs = service.list_jobs(include_disabled=all)
|
jobs = service.list_jobs(include_disabled=all)
|
||||||
|
|
||||||
if not jobs:
|
if not jobs:
|
||||||
console.print("No scheduled jobs.")
|
console.print("No scheduled jobs.")
|
||||||
return
|
return
|
||||||
|
|
||||||
table = Table(title="Scheduled Jobs")
|
table = Table(title="Scheduled Jobs")
|
||||||
table.add_column("ID", style="cyan")
|
table.add_column("ID", style="cyan")
|
||||||
table.add_column("Name")
|
table.add_column("Name")
|
||||||
table.add_column("Schedule")
|
table.add_column("Schedule")
|
||||||
table.add_column("Status")
|
table.add_column("Status")
|
||||||
table.add_column("Next Run")
|
table.add_column("Next Run")
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime as _dt
|
from datetime import datetime as _dt
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
@@ -798,7 +799,7 @@ def cron_list(
|
|||||||
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
|
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
|
||||||
else:
|
else:
|
||||||
sched = "one-time"
|
sched = "one-time"
|
||||||
|
|
||||||
# Format next run
|
# Format next run
|
||||||
next_run = ""
|
next_run = ""
|
||||||
if job.state.next_run_at_ms:
|
if job.state.next_run_at_ms:
|
||||||
@@ -808,11 +809,11 @@ def cron_list(
|
|||||||
next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
|
next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
|
||||||
except Exception:
|
except Exception:
|
||||||
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
|
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
|
||||||
|
|
||||||
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
|
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
|
||||||
|
|
||||||
table.add_row(job.id, job.name, sched, status, next_run)
|
table.add_row(job.id, job.name, sched, status, next_run)
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
@@ -832,7 +833,7 @@ def cron_add(
|
|||||||
from nanobot.config.loader import get_data_dir
|
from nanobot.config.loader import get_data_dir
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronSchedule
|
from nanobot.cron.types import CronSchedule
|
||||||
|
|
||||||
if tz and not cron_expr:
|
if tz and not cron_expr:
|
||||||
console.print("[red]Error: --tz can only be used with --cron[/red]")
|
console.print("[red]Error: --tz can only be used with --cron[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
@@ -849,10 +850,10 @@ def cron_add(
|
|||||||
else:
|
else:
|
||||||
console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
|
console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
job = service.add_job(
|
job = service.add_job(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -876,10 +877,10 @@ def cron_remove(
|
|||||||
"""Remove a scheduled job."""
|
"""Remove a scheduled job."""
|
||||||
from nanobot.config.loader import get_data_dir
|
from nanobot.config.loader import get_data_dir
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
if service.remove_job(job_id):
|
if service.remove_job(job_id):
|
||||||
console.print(f"[green]✓[/green] Removed job {job_id}")
|
console.print(f"[green]✓[/green] Removed job {job_id}")
|
||||||
else:
|
else:
|
||||||
@@ -894,10 +895,10 @@ def cron_enable(
|
|||||||
"""Enable or disable a job."""
|
"""Enable or disable a job."""
|
||||||
from nanobot.config.loader import get_data_dir
|
from nanobot.config.loader import get_data_dir
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
job = service.enable_job(job_id, enabled=not disable)
|
job = service.enable_job(job_id, enabled=not disable)
|
||||||
if job:
|
if job:
|
||||||
status = "disabled" if disable else "enabled"
|
status = "disabled" if disable else "enabled"
|
||||||
@@ -913,11 +914,12 @@ def cron_run(
|
|||||||
):
|
):
|
||||||
"""Manually run a job."""
|
"""Manually run a job."""
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nanobot.config.loader import load_config, get_data_dir
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.loader import get_data_dir, load_config
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronJob
|
from nanobot.cron.types import CronJob
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
logger.disable("nanobot")
|
logger.disable("nanobot")
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
@@ -975,7 +977,7 @@ def cron_run(
|
|||||||
@app.command()
|
@app.command()
|
||||||
def status():
|
def status():
|
||||||
"""Show nanobot status."""
|
"""Show nanobot status."""
|
||||||
from nanobot.config.loader import load_config, get_config_path
|
from nanobot.config.loader import get_config_path, load_config
|
||||||
|
|
||||||
config_path = get_config_path()
|
config_path = get_config_path()
|
||||||
config = load_config()
|
config = load_config()
|
||||||
@@ -990,7 +992,7 @@ def status():
|
|||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS
|
||||||
|
|
||||||
console.print(f"Model: {config.agents.defaults.model}")
|
console.print(f"Model: {config.agents.defaults.model}")
|
||||||
|
|
||||||
# Check API keys from registry
|
# Check API keys from registry
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
p = getattr(config.providers, spec.name, None)
|
p = getattr(config.providers, spec.name, None)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Configuration module for nanobot."""
|
"""Configuration module for nanobot."""
|
||||||
|
|
||||||
from nanobot.config.loader import load_config, get_config_path
|
from nanobot.config.loader import get_config_path, load_config
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
__all__ = ["Config", "load_config", "get_config_path"]
|
__all__ = ["Config", "load_config", "get_config_path"]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from pydantic.alias_generators import to_camel
|
from pydantic.alias_generators import to_camel
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|||||||
@@ -21,17 +21,18 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
|||||||
"""Compute next run time in ms."""
|
"""Compute next run time in ms."""
|
||||||
if schedule.kind == "at":
|
if schedule.kind == "at":
|
||||||
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
|
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
|
||||||
|
|
||||||
if schedule.kind == "every":
|
if schedule.kind == "every":
|
||||||
if not schedule.every_ms or schedule.every_ms <= 0:
|
if not schedule.every_ms or schedule.every_ms <= 0:
|
||||||
return None
|
return None
|
||||||
# Next interval from now
|
# Next interval from now
|
||||||
return now_ms + schedule.every_ms
|
return now_ms + schedule.every_ms
|
||||||
|
|
||||||
if schedule.kind == "cron" and schedule.expr:
|
if schedule.kind == "cron" and schedule.expr:
|
||||||
try:
|
try:
|
||||||
from croniter import croniter
|
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from croniter import croniter
|
||||||
# Use caller-provided reference time for deterministic scheduling
|
# Use caller-provided reference time for deterministic scheduling
|
||||||
base_time = now_ms / 1000
|
base_time = now_ms / 1000
|
||||||
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
|
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
|
||||||
@@ -41,7 +42,7 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
|||||||
return int(next_dt.timestamp() * 1000)
|
return int(next_dt.timestamp() * 1000)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -61,7 +62,7 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
|
|||||||
|
|
||||||
class CronService:
|
class CronService:
|
||||||
"""Service for managing and executing scheduled jobs."""
|
"""Service for managing and executing scheduled jobs."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
store_path: Path,
|
store_path: Path,
|
||||||
@@ -72,12 +73,12 @@ class CronService:
|
|||||||
self._store: CronStore | None = None
|
self._store: CronStore | None = None
|
||||||
self._timer_task: asyncio.Task | None = None
|
self._timer_task: asyncio.Task | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
def _load_store(self) -> CronStore:
|
def _load_store(self) -> CronStore:
|
||||||
"""Load jobs from disk."""
|
"""Load jobs from disk."""
|
||||||
if self._store:
|
if self._store:
|
||||||
return self._store
|
return self._store
|
||||||
|
|
||||||
if self.store_path.exists():
|
if self.store_path.exists():
|
||||||
try:
|
try:
|
||||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||||
@@ -117,16 +118,16 @@ class CronService:
|
|||||||
self._store = CronStore()
|
self._store = CronStore()
|
||||||
else:
|
else:
|
||||||
self._store = CronStore()
|
self._store = CronStore()
|
||||||
|
|
||||||
return self._store
|
return self._store
|
||||||
|
|
||||||
def _save_store(self) -> None:
|
def _save_store(self) -> None:
|
||||||
"""Save jobs to disk."""
|
"""Save jobs to disk."""
|
||||||
if not self._store:
|
if not self._store:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"version": self._store.version,
|
"version": self._store.version,
|
||||||
"jobs": [
|
"jobs": [
|
||||||
@@ -161,9 +162,9 @@ class CronService:
|
|||||||
for j in self._store.jobs
|
for j in self._store.jobs
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the cron service."""
|
"""Start the cron service."""
|
||||||
self._running = True
|
self._running = True
|
||||||
@@ -172,14 +173,14 @@ class CronService:
|
|||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
|
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the cron service."""
|
"""Stop the cron service."""
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._timer_task:
|
if self._timer_task:
|
||||||
self._timer_task.cancel()
|
self._timer_task.cancel()
|
||||||
self._timer_task = None
|
self._timer_task = None
|
||||||
|
|
||||||
def _recompute_next_runs(self) -> None:
|
def _recompute_next_runs(self) -> None:
|
||||||
"""Recompute next run times for all enabled jobs."""
|
"""Recompute next run times for all enabled jobs."""
|
||||||
if not self._store:
|
if not self._store:
|
||||||
@@ -188,73 +189,73 @@ class CronService:
|
|||||||
for job in self._store.jobs:
|
for job in self._store.jobs:
|
||||||
if job.enabled:
|
if job.enabled:
|
||||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
|
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
|
||||||
|
|
||||||
def _get_next_wake_ms(self) -> int | None:
|
def _get_next_wake_ms(self) -> int | None:
|
||||||
"""Get the earliest next run time across all jobs."""
|
"""Get the earliest next run time across all jobs."""
|
||||||
if not self._store:
|
if not self._store:
|
||||||
return None
|
return None
|
||||||
times = [j.state.next_run_at_ms for j in self._store.jobs
|
times = [j.state.next_run_at_ms for j in self._store.jobs
|
||||||
if j.enabled and j.state.next_run_at_ms]
|
if j.enabled and j.state.next_run_at_ms]
|
||||||
return min(times) if times else None
|
return min(times) if times else None
|
||||||
|
|
||||||
def _arm_timer(self) -> None:
|
def _arm_timer(self) -> None:
|
||||||
"""Schedule the next timer tick."""
|
"""Schedule the next timer tick."""
|
||||||
if self._timer_task:
|
if self._timer_task:
|
||||||
self._timer_task.cancel()
|
self._timer_task.cancel()
|
||||||
|
|
||||||
next_wake = self._get_next_wake_ms()
|
next_wake = self._get_next_wake_ms()
|
||||||
if not next_wake or not self._running:
|
if not next_wake or not self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
delay_ms = max(0, next_wake - _now_ms())
|
delay_ms = max(0, next_wake - _now_ms())
|
||||||
delay_s = delay_ms / 1000
|
delay_s = delay_ms / 1000
|
||||||
|
|
||||||
async def tick():
|
async def tick():
|
||||||
await asyncio.sleep(delay_s)
|
await asyncio.sleep(delay_s)
|
||||||
if self._running:
|
if self._running:
|
||||||
await self._on_timer()
|
await self._on_timer()
|
||||||
|
|
||||||
self._timer_task = asyncio.create_task(tick())
|
self._timer_task = asyncio.create_task(tick())
|
||||||
|
|
||||||
async def _on_timer(self) -> None:
|
async def _on_timer(self) -> None:
|
||||||
"""Handle timer tick - run due jobs."""
|
"""Handle timer tick - run due jobs."""
|
||||||
if not self._store:
|
if not self._store:
|
||||||
return
|
return
|
||||||
|
|
||||||
now = _now_ms()
|
now = _now_ms()
|
||||||
due_jobs = [
|
due_jobs = [
|
||||||
j for j in self._store.jobs
|
j for j in self._store.jobs
|
||||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||||
]
|
]
|
||||||
|
|
||||||
for job in due_jobs:
|
for job in due_jobs:
|
||||||
await self._execute_job(job)
|
await self._execute_job(job)
|
||||||
|
|
||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
|
|
||||||
async def _execute_job(self, job: CronJob) -> None:
|
async def _execute_job(self, job: CronJob) -> None:
|
||||||
"""Execute a single job."""
|
"""Execute a single job."""
|
||||||
start_ms = _now_ms()
|
start_ms = _now_ms()
|
||||||
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
|
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = None
|
response = None
|
||||||
if self.on_job:
|
if self.on_job:
|
||||||
response = await self.on_job(job)
|
response = await self.on_job(job)
|
||||||
|
|
||||||
job.state.last_status = "ok"
|
job.state.last_status = "ok"
|
||||||
job.state.last_error = None
|
job.state.last_error = None
|
||||||
logger.info("Cron: job '{}' completed", job.name)
|
logger.info("Cron: job '{}' completed", job.name)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
job.state.last_status = "error"
|
job.state.last_status = "error"
|
||||||
job.state.last_error = str(e)
|
job.state.last_error = str(e)
|
||||||
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
||||||
|
|
||||||
job.state.last_run_at_ms = start_ms
|
job.state.last_run_at_ms = start_ms
|
||||||
job.updated_at_ms = _now_ms()
|
job.updated_at_ms = _now_ms()
|
||||||
|
|
||||||
# Handle one-shot jobs
|
# Handle one-shot jobs
|
||||||
if job.schedule.kind == "at":
|
if job.schedule.kind == "at":
|
||||||
if job.delete_after_run:
|
if job.delete_after_run:
|
||||||
@@ -265,15 +266,15 @@ class CronService:
|
|||||||
else:
|
else:
|
||||||
# Compute next run
|
# Compute next run
|
||||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||||
|
|
||||||
# ========== Public API ==========
|
# ========== Public API ==========
|
||||||
|
|
||||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||||
"""List all jobs."""
|
"""List all jobs."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
|
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
|
||||||
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
|
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
|
||||||
|
|
||||||
def add_job(
|
def add_job(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@@ -288,7 +289,7 @@ class CronService:
|
|||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
_validate_schedule_for_add(schedule)
|
_validate_schedule_for_add(schedule)
|
||||||
now = _now_ms()
|
now = _now_ms()
|
||||||
|
|
||||||
job = CronJob(
|
job = CronJob(
|
||||||
id=str(uuid.uuid4())[:8],
|
id=str(uuid.uuid4())[:8],
|
||||||
name=name,
|
name=name,
|
||||||
@@ -306,28 +307,28 @@ class CronService:
|
|||||||
updated_at_ms=now,
|
updated_at_ms=now,
|
||||||
delete_after_run=delete_after_run,
|
delete_after_run=delete_after_run,
|
||||||
)
|
)
|
||||||
|
|
||||||
store.jobs.append(job)
|
store.jobs.append(job)
|
||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
|
|
||||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||||
return job
|
return job
|
||||||
|
|
||||||
def remove_job(self, job_id: str) -> bool:
|
def remove_job(self, job_id: str) -> bool:
|
||||||
"""Remove a job by ID."""
|
"""Remove a job by ID."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
before = len(store.jobs)
|
before = len(store.jobs)
|
||||||
store.jobs = [j for j in store.jobs if j.id != job_id]
|
store.jobs = [j for j in store.jobs if j.id != job_id]
|
||||||
removed = len(store.jobs) < before
|
removed = len(store.jobs) < before
|
||||||
|
|
||||||
if removed:
|
if removed:
|
||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
logger.info("Cron: removed job {}", job_id)
|
logger.info("Cron: removed job {}", job_id)
|
||||||
|
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
|
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
|
||||||
"""Enable or disable a job."""
|
"""Enable or disable a job."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
@@ -343,7 +344,7 @@ class CronService:
|
|||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
return job
|
return job
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||||
"""Manually run a job."""
|
"""Manually run a job."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
@@ -356,7 +357,7 @@ class CronService:
|
|||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def status(self) -> dict:
|
def status(self) -> dict:
|
||||||
"""Get service status."""
|
"""Get service status."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class LLMResponse:
|
|||||||
finish_reason: str = "stop"
|
finish_reason: str = "stop"
|
||||||
usage: dict[str, int] = field(default_factory=dict)
|
usage: dict[str, int] = field(default_factory=dict)
|
||||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_tool_calls(self) -> bool:
|
def has_tool_calls(self) -> bool:
|
||||||
"""Check if response contains tool calls."""
|
"""Check if response contains tool calls."""
|
||||||
@@ -35,7 +35,7 @@ class LLMProvider(ABC):
|
|||||||
Implementations should handle the specifics of each provider's API
|
Implementations should handle the specifics of each provider's API
|
||||||
while maintaining a consistent interface.
|
while maintaining a consistent interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
@@ -79,7 +79,7 @@ class LLMProvider(ABC):
|
|||||||
|
|
||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@@ -103,7 +103,7 @@ class LLMProvider(ABC):
|
|||||||
LLMResponse with content and/or tool calls.
|
LLMResponse with content and/or tool calls.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
"""Get the default model for this provider."""
|
"""Get the default model for this provider."""
|
||||||
|
|||||||
@@ -1,19 +1,17 @@
|
|||||||
"""LiteLLM provider implementation for multi-provider support."""
|
"""LiteLLM provider implementation for multi-provider support."""
|
||||||
|
|
||||||
import json
|
|
||||||
import json_repair
|
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import json_repair
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
|
|
||||||
|
|
||||||
# Standard OpenAI chat-completion message keys plus reasoning_content for
|
# Standard OpenAI chat-completion message keys plus reasoning_content for
|
||||||
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
|
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
|
||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
@@ -32,10 +30,10 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
a unified interface. Provider-specific logic is driven by the registry
|
a unified interface. Provider-specific logic is driven by the registry
|
||||||
(see providers/registry.py) — no if-elif chains needed here.
|
(see providers/registry.py) — no if-elif chains needed here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
default_model: str = "anthropic/claude-opus-4-5",
|
default_model: str = "anthropic/claude-opus-4-5",
|
||||||
extra_headers: dict[str, str] | None = None,
|
extra_headers: dict[str, str] | None = None,
|
||||||
@@ -44,24 +42,24 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self.extra_headers = extra_headers or {}
|
self.extra_headers = extra_headers or {}
|
||||||
|
|
||||||
# Detect gateway / local deployment.
|
# Detect gateway / local deployment.
|
||||||
# provider_name (from config key) is the primary signal;
|
# provider_name (from config key) is the primary signal;
|
||||||
# api_key / api_base are fallback for auto-detection.
|
# api_key / api_base are fallback for auto-detection.
|
||||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
if api_key:
|
if api_key:
|
||||||
self._setup_env(api_key, api_base, default_model)
|
self._setup_env(api_key, api_base, default_model)
|
||||||
|
|
||||||
if api_base:
|
if api_base:
|
||||||
litellm.api_base = api_base
|
litellm.api_base = api_base
|
||||||
|
|
||||||
# Disable LiteLLM logging noise
|
# Disable LiteLLM logging noise
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||||
"""Set environment variables based on detected provider."""
|
"""Set environment variables based on detected provider."""
|
||||||
spec = self._gateway or find_by_model(model)
|
spec = self._gateway or find_by_model(model)
|
||||||
@@ -85,7 +83,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
resolved = env_val.replace("{api_key}", api_key)
|
resolved = env_val.replace("{api_key}", api_key)
|
||||||
resolved = resolved.replace("{api_base}", effective_base)
|
resolved = resolved.replace("{api_base}", effective_base)
|
||||||
os.environ.setdefault(env_name, resolved)
|
os.environ.setdefault(env_name, resolved)
|
||||||
|
|
||||||
def _resolve_model(self, model: str) -> str:
|
def _resolve_model(self, model: str) -> str:
|
||||||
"""Resolve model name by applying provider/gateway prefixes."""
|
"""Resolve model name by applying provider/gateway prefixes."""
|
||||||
if self._gateway:
|
if self._gateway:
|
||||||
@@ -96,7 +94,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if prefix and not model.startswith(f"{prefix}/"):
|
if prefix and not model.startswith(f"{prefix}/"):
|
||||||
model = f"{prefix}/{model}"
|
model = f"{prefix}/{model}"
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# Standard mode: auto-prefix for known providers
|
# Standard mode: auto-prefix for known providers
|
||||||
spec = find_by_model(model)
|
spec = find_by_model(model)
|
||||||
if spec and spec.litellm_prefix:
|
if spec and spec.litellm_prefix:
|
||||||
@@ -115,7 +113,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if prefix.lower().replace("-", "_") != spec_name:
|
if prefix.lower().replace("-", "_") != spec_name:
|
||||||
return model
|
return model
|
||||||
return f"{canonical_prefix}/{remainder}"
|
return f"{canonical_prefix}/{remainder}"
|
||||||
|
|
||||||
def _supports_cache_control(self, model: str) -> bool:
|
def _supports_cache_control(self, model: str) -> bool:
|
||||||
"""Return True when the provider supports cache_control on content blocks."""
|
"""Return True when the provider supports cache_control on content blocks."""
|
||||||
if self._gateway is not None:
|
if self._gateway is not None:
|
||||||
@@ -158,7 +156,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if pattern in model_lower:
|
if pattern in model_lower:
|
||||||
kwargs.update(overrides)
|
kwargs.update(overrides)
|
||||||
return
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||||
@@ -181,14 +179,14 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request via LiteLLM.
|
Send a chat completion request via LiteLLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message dicts with 'role' and 'content'.
|
messages: List of message dicts with 'role' and 'content'.
|
||||||
tools: Optional list of tool definitions in OpenAI format.
|
tools: Optional list of tool definitions in OpenAI format.
|
||||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
||||||
max_tokens: Maximum tokens in response.
|
max_tokens: Maximum tokens in response.
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMResponse with content and/or tool calls.
|
LLMResponse with content and/or tool calls.
|
||||||
"""
|
"""
|
||||||
@@ -201,33 +199,33 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||||
max_tokens = max(1, max_tokens)
|
max_tokens = max(1, max_tokens)
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||||
self._apply_model_overrides(model, kwargs)
|
self._apply_model_overrides(model, kwargs)
|
||||||
|
|
||||||
# Pass api_key directly — more reliable than env vars alone
|
# Pass api_key directly — more reliable than env vars alone
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
kwargs["api_key"] = self.api_key
|
kwargs["api_key"] = self.api_key
|
||||||
|
|
||||||
# Pass api_base for custom endpoints
|
# Pass api_base for custom endpoints
|
||||||
if self.api_base:
|
if self.api_base:
|
||||||
kwargs["api_base"] = self.api_base
|
kwargs["api_base"] = self.api_base
|
||||||
|
|
||||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
||||||
if self.extra_headers:
|
if self.extra_headers:
|
||||||
kwargs["extra_headers"] = self.extra_headers
|
kwargs["extra_headers"] = self.extra_headers
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
kwargs["tool_choice"] = "auto"
|
kwargs["tool_choice"] = "auto"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(**kwargs)
|
response = await acompletion(**kwargs)
|
||||||
return self._parse_response(response)
|
return self._parse_response(response)
|
||||||
@@ -237,12 +235,12 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
content=f"Error calling LLM: {str(e)}",
|
content=f"Error calling LLM: {str(e)}",
|
||||||
finish_reason="error",
|
finish_reason="error",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_response(self, response: Any) -> LLMResponse:
|
def _parse_response(self, response: Any) -> LLMResponse:
|
||||||
"""Parse LiteLLM response into our standard format."""
|
"""Parse LiteLLM response into our standard format."""
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
message = choice.message
|
message = choice.message
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
for tc in message.tool_calls:
|
for tc in message.tool_calls:
|
||||||
@@ -250,13 +248,13 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
args = tc.function.arguments
|
args = tc.function.arguments
|
||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=_short_tool_id(),
|
id=_short_tool_id(),
|
||||||
name=tc.function.name,
|
name=tc.function.name,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
))
|
))
|
||||||
|
|
||||||
usage = {}
|
usage = {}
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
usage = {
|
usage = {
|
||||||
@@ -264,9 +262,9 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
"completion_tokens": response.usage.completion_tokens,
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
"total_tokens": response.usage.total_tokens,
|
"total_tokens": response.usage.total_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=message.content,
|
content=message.content,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
@@ -274,7 +272,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
usage=usage,
|
usage=usage,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
"""Get the default model."""
|
"""Get the default model."""
|
||||||
return self.default_model
|
return self.default_model
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from oauth_cli_kit import get_token as get_codex_token
|
from oauth_cli_kit import get_token as get_codex_token
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -11,33 +10,33 @@ from loguru import logger
|
|||||||
class GroqTranscriptionProvider:
|
class GroqTranscriptionProvider:
|
||||||
"""
|
"""
|
||||||
Voice transcription provider using Groq's Whisper API.
|
Voice transcription provider using Groq's Whisper API.
|
||||||
|
|
||||||
Groq offers extremely fast transcription with a generous free tier.
|
Groq offers extremely fast transcription with a generous free tier.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
||||||
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
|
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
|
||||||
|
|
||||||
async def transcribe(self, file_path: str | Path) -> str:
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe an audio file using Groq.
|
Transcribe an audio file using Groq.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the audio file.
|
file_path: Path to the audio file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Transcribed text.
|
Transcribed text.
|
||||||
"""
|
"""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
logger.warning("Groq API key not configured for transcription")
|
logger.warning("Groq API key not configured for transcription")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
logger.error("Audio file not found: {}", file_path)
|
logger.error("Audio file not found: {}", file_path)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
|
|||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self.api_url,
|
self.api_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
files=files,
|
files=files,
|
||||||
timeout=60.0
|
timeout=60.0
|
||||||
)
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return data.get("text", "")
|
return data.get("text", "")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Groq transcription error: {}", e)
|
logger.error("Groq transcription error: {}", e)
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Session management module."""
|
"""Session management module."""
|
||||||
|
|
||||||
from nanobot.session.manager import SessionManager, Session
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
__all__ = ["SessionManager", "Session"]
|
__all__ = ["SessionManager", "Session"]
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -30,7 +30,7 @@ class Session:
|
|||||||
updated_at: datetime = field(default_factory=datetime.now)
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||||
|
|
||||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||||
"""Add a message to the session."""
|
"""Add a message to the session."""
|
||||||
msg = {
|
msg = {
|
||||||
@@ -41,7 +41,7 @@ class Session:
|
|||||||
}
|
}
|
||||||
self.messages.append(msg)
|
self.messages.append(msg)
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
||||||
unconsolidated = self.messages[self.last_consolidated:]
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
@@ -61,7 +61,7 @@ class Session:
|
|||||||
entry[k] = m[k]
|
entry[k] = m[k]
|
||||||
out.append(entry)
|
out.append(entry)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all messages and reset session to initial state."""
|
"""Clear all messages and reset session to initial state."""
|
||||||
self.messages = []
|
self.messages = []
|
||||||
@@ -81,7 +81,7 @@ class SessionManager:
|
|||||||
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
||||||
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
|
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
|
||||||
self._cache: dict[str, Session] = {}
|
self._cache: dict[str, Session] = {}
|
||||||
|
|
||||||
def _get_session_path(self, key: str) -> Path:
|
def _get_session_path(self, key: str) -> Path:
|
||||||
"""Get the file path for a session."""
|
"""Get the file path for a session."""
|
||||||
safe_key = safe_filename(key.replace(":", "_"))
|
safe_key = safe_filename(key.replace(":", "_"))
|
||||||
@@ -91,27 +91,27 @@ class SessionManager:
|
|||||||
"""Legacy global session path (~/.nanobot/sessions/)."""
|
"""Legacy global session path (~/.nanobot/sessions/)."""
|
||||||
safe_key = safe_filename(key.replace(":", "_"))
|
safe_key = safe_filename(key.replace(":", "_"))
|
||||||
return self.legacy_sessions_dir / f"{safe_key}.jsonl"
|
return self.legacy_sessions_dir / f"{safe_key}.jsonl"
|
||||||
|
|
||||||
def get_or_create(self, key: str) -> Session:
|
def get_or_create(self, key: str) -> Session:
|
||||||
"""
|
"""
|
||||||
Get an existing session or create a new one.
|
Get an existing session or create a new one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: Session key (usually channel:chat_id).
|
key: Session key (usually channel:chat_id).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The session.
|
The session.
|
||||||
"""
|
"""
|
||||||
if key in self._cache:
|
if key in self._cache:
|
||||||
return self._cache[key]
|
return self._cache[key]
|
||||||
|
|
||||||
session = self._load(key)
|
session = self._load(key)
|
||||||
if session is None:
|
if session is None:
|
||||||
session = Session(key=key)
|
session = Session(key=key)
|
||||||
|
|
||||||
self._cache[key] = session
|
self._cache[key] = session
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def _load(self, key: str) -> Session | None:
|
def _load(self, key: str) -> Session | None:
|
||||||
"""Load a session from disk."""
|
"""Load a session from disk."""
|
||||||
path = self._get_session_path(key)
|
path = self._get_session_path(key)
|
||||||
@@ -158,7 +158,7 @@ class SessionManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load session {}: {}", key, e)
|
logger.warning("Failed to load session {}: {}", key, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save(self, session: Session) -> None:
|
def save(self, session: Session) -> None:
|
||||||
"""Save a session to disk."""
|
"""Save a session to disk."""
|
||||||
path = self._get_session_path(session.key)
|
path = self._get_session_path(session.key)
|
||||||
@@ -177,20 +177,20 @@ class SessionManager:
|
|||||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
self._cache[session.key] = session
|
self._cache[session.key] = session
|
||||||
|
|
||||||
def invalidate(self, key: str) -> None:
|
def invalidate(self, key: str) -> None:
|
||||||
"""Remove a session from the in-memory cache."""
|
"""Remove a session from the in-memory cache."""
|
||||||
self._cache.pop(key, None)
|
self._cache.pop(key, None)
|
||||||
|
|
||||||
def list_sessions(self) -> list[dict[str, Any]]:
|
def list_sessions(self) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
List all sessions.
|
List all sessions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of session info dicts.
|
List of session info dicts.
|
||||||
"""
|
"""
|
||||||
sessions = []
|
sessions = []
|
||||||
|
|
||||||
for path in self.sessions_dir.glob("*.jsonl"):
|
for path in self.sessions_dir.glob("*.jsonl"):
|
||||||
try:
|
try:
|
||||||
# Read just the metadata line
|
# Read just the metadata line
|
||||||
@@ -208,5 +208,5 @@ class SessionManager:
|
|||||||
})
|
})
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
|
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Utility functions for nanobot."""
|
"""Utility functions for nanobot."""
|
||||||
|
|
||||||
from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path
|
from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
|
||||||
|
|
||||||
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
|
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Utility functions for nanobot."""
|
"""Utility functions for nanobot."""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def ensure_dir(path: Path) -> Path:
|
def ensure_dir(path: Path) -> Path:
|
||||||
|
|||||||
Reference in New Issue
Block a user