style: unify code formatting and import order

- Remove trailing whitespace and normalize blank lines
- Unify string quotes and line breaks for long lines
- Sort imports alphabetically across modules
This commit is contained in:
JK_Lu
2026-02-28 20:55:43 +08:00
parent bfc2fa88f3
commit 977ca725f2
33 changed files with 574 additions and 581 deletions

View File

@@ -1,7 +1,7 @@
"""Agent core module.""" """Agent core module."""
from nanobot.agent.loop import AgentLoop
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader

View File

@@ -14,15 +14,15 @@ from nanobot.agent.skills import SkillsLoader
class ContextBuilder: class ContextBuilder:
"""Builds the context (system prompt + messages) for the agent.""" """Builds the context (system prompt + messages) for the agent."""
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.workspace = workspace self.workspace = workspace
self.memory = MemoryStore(workspace) self.memory = MemoryStore(workspace)
self.skills = SkillsLoader(workspace) self.skills = SkillsLoader(workspace)
def build_system_prompt(self, skill_names: list[str] | None = None) -> str: def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills.""" """Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity()] parts = [self._get_identity()]
@@ -51,13 +51,13 @@ Skills with available="false" need dependencies installed first - you can try in
{skills_summary}""") {skills_summary}""")
return "\n\n---\n\n".join(parts) return "\n\n---\n\n".join(parts)
def _get_identity(self) -> str: def _get_identity(self) -> str:
"""Get the core identity section.""" """Get the core identity section."""
workspace_path = str(self.workspace.expanduser().resolve()) workspace_path = str(self.workspace.expanduser().resolve())
system = platform.system() system = platform.system()
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
return f"""# nanobot 🐈 return f"""# nanobot 🐈
You are nanobot, a helpful AI assistant. You are nanobot, a helpful AI assistant.
@@ -89,19 +89,19 @@ Reply directly with text for conversations. Only use the 'message' tool to send
if channel and chat_id: if channel and chat_id:
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
def _load_bootstrap_files(self) -> str: def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace.""" """Load all bootstrap files from workspace."""
parts = [] parts = []
for filename in self.BOOTSTRAP_FILES: for filename in self.BOOTSTRAP_FILES:
file_path = self.workspace / filename file_path = self.workspace / filename
if file_path.exists(): if file_path.exists():
content = file_path.read_text(encoding="utf-8") content = file_path.read_text(encoding="utf-8")
parts.append(f"## {filename}\n\n{content}") parts.append(f"## {filename}\n\n{content}")
return "\n\n".join(parts) if parts else "" return "\n\n".join(parts) if parts else ""
def build_messages( def build_messages(
self, self,
history: list[dict[str, Any]], history: list[dict[str, Any]],
@@ -123,7 +123,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
"""Build user message content with optional base64-encoded images.""" """Build user message content with optional base64-encoded images."""
if not media: if not media:
return text return text
images = [] images = []
for path in media: for path in media:
p = Path(path) p = Path(path)
@@ -132,11 +132,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send
continue continue
b64 = base64.b64encode(p.read_bytes()).decode() b64 = base64.b64encode(p.read_bytes()).decode()
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
if not images: if not images:
return text return text
return images + [{"type": "text", "text": text}] return images + [{"type": "text", "text": text}]
def add_tool_result( def add_tool_result(
self, messages: list[dict[str, Any]], self, messages: list[dict[str, Any]],
tool_call_id: str, tool_name: str, result: str, tool_call_id: str, tool_name: str, result: str,
@@ -144,7 +144,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
"""Add a tool result to the message list.""" """Add a tool result to the message list."""
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result}) messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
return messages return messages
def add_assistant_message( def add_assistant_message(
self, messages: list[dict[str, Any]], self, messages: list[dict[str, Any]],
content: str | None, content: str | None,

View File

@@ -13,28 +13,28 @@ BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
class SkillsLoader: class SkillsLoader:
""" """
Loader for agent skills. Loader for agent skills.
Skills are markdown files (SKILL.md) that teach the agent how to use Skills are markdown files (SKILL.md) that teach the agent how to use
specific tools or perform certain tasks. specific tools or perform certain tasks.
""" """
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None): def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
self.workspace = workspace self.workspace = workspace
self.workspace_skills = workspace / "skills" self.workspace_skills = workspace / "skills"
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]: def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
""" """
List all available skills. List all available skills.
Args: Args:
filter_unavailable: If True, filter out skills with unmet requirements. filter_unavailable: If True, filter out skills with unmet requirements.
Returns: Returns:
List of skill info dicts with 'name', 'path', 'source'. List of skill info dicts with 'name', 'path', 'source'.
""" """
skills = [] skills = []
# Workspace skills (highest priority) # Workspace skills (highest priority)
if self.workspace_skills.exists(): if self.workspace_skills.exists():
for skill_dir in self.workspace_skills.iterdir(): for skill_dir in self.workspace_skills.iterdir():
@@ -42,7 +42,7 @@ class SkillsLoader:
skill_file = skill_dir / "SKILL.md" skill_file = skill_dir / "SKILL.md"
if skill_file.exists(): if skill_file.exists():
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"}) skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
# Built-in skills # Built-in skills
if self.builtin_skills and self.builtin_skills.exists(): if self.builtin_skills and self.builtin_skills.exists():
for skill_dir in self.builtin_skills.iterdir(): for skill_dir in self.builtin_skills.iterdir():
@@ -50,19 +50,19 @@ class SkillsLoader:
skill_file = skill_dir / "SKILL.md" skill_file = skill_dir / "SKILL.md"
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
# Filter by requirements # Filter by requirements
if filter_unavailable: if filter_unavailable:
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))] return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
return skills return skills
def load_skill(self, name: str) -> str | None: def load_skill(self, name: str) -> str | None:
""" """
Load a skill by name. Load a skill by name.
Args: Args:
name: Skill name (directory name). name: Skill name (directory name).
Returns: Returns:
Skill content or None if not found. Skill content or None if not found.
""" """
@@ -70,22 +70,22 @@ class SkillsLoader:
workspace_skill = self.workspace_skills / name / "SKILL.md" workspace_skill = self.workspace_skills / name / "SKILL.md"
if workspace_skill.exists(): if workspace_skill.exists():
return workspace_skill.read_text(encoding="utf-8") return workspace_skill.read_text(encoding="utf-8")
# Check built-in # Check built-in
if self.builtin_skills: if self.builtin_skills:
builtin_skill = self.builtin_skills / name / "SKILL.md" builtin_skill = self.builtin_skills / name / "SKILL.md"
if builtin_skill.exists(): if builtin_skill.exists():
return builtin_skill.read_text(encoding="utf-8") return builtin_skill.read_text(encoding="utf-8")
return None return None
def load_skills_for_context(self, skill_names: list[str]) -> str: def load_skills_for_context(self, skill_names: list[str]) -> str:
""" """
Load specific skills for inclusion in agent context. Load specific skills for inclusion in agent context.
Args: Args:
skill_names: List of skill names to load. skill_names: List of skill names to load.
Returns: Returns:
Formatted skills content. Formatted skills content.
""" """
@@ -95,26 +95,26 @@ class SkillsLoader:
if content: if content:
content = self._strip_frontmatter(content) content = self._strip_frontmatter(content)
parts.append(f"### Skill: {name}\n\n{content}") parts.append(f"### Skill: {name}\n\n{content}")
return "\n\n---\n\n".join(parts) if parts else "" return "\n\n---\n\n".join(parts) if parts else ""
def build_skills_summary(self) -> str: def build_skills_summary(self) -> str:
""" """
Build a summary of all skills (name, description, path, availability). Build a summary of all skills (name, description, path, availability).
This is used for progressive loading - the agent can read the full This is used for progressive loading - the agent can read the full
skill content using read_file when needed. skill content using read_file when needed.
Returns: Returns:
XML-formatted skills summary. XML-formatted skills summary.
""" """
all_skills = self.list_skills(filter_unavailable=False) all_skills = self.list_skills(filter_unavailable=False)
if not all_skills: if not all_skills:
return "" return ""
def escape_xml(s: str) -> str: def escape_xml(s: str) -> str:
return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
lines = ["<skills>"] lines = ["<skills>"]
for s in all_skills: for s in all_skills:
name = escape_xml(s["name"]) name = escape_xml(s["name"])
@@ -122,23 +122,23 @@ class SkillsLoader:
desc = escape_xml(self._get_skill_description(s["name"])) desc = escape_xml(self._get_skill_description(s["name"]))
skill_meta = self._get_skill_meta(s["name"]) skill_meta = self._get_skill_meta(s["name"])
available = self._check_requirements(skill_meta) available = self._check_requirements(skill_meta)
lines.append(f" <skill available=\"{str(available).lower()}\">") lines.append(f" <skill available=\"{str(available).lower()}\">")
lines.append(f" <name>{name}</name>") lines.append(f" <name>{name}</name>")
lines.append(f" <description>{desc}</description>") lines.append(f" <description>{desc}</description>")
lines.append(f" <location>{path}</location>") lines.append(f" <location>{path}</location>")
# Show missing requirements for unavailable skills # Show missing requirements for unavailable skills
if not available: if not available:
missing = self._get_missing_requirements(skill_meta) missing = self._get_missing_requirements(skill_meta)
if missing: if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>") lines.append(f" <requires>{escape_xml(missing)}</requires>")
lines.append(f" </skill>") lines.append(" </skill>")
lines.append("</skills>") lines.append("</skills>")
return "\n".join(lines) return "\n".join(lines)
def _get_missing_requirements(self, skill_meta: dict) -> str: def _get_missing_requirements(self, skill_meta: dict) -> str:
"""Get a description of missing requirements.""" """Get a description of missing requirements."""
missing = [] missing = []
@@ -150,14 +150,14 @@ class SkillsLoader:
if not os.environ.get(env): if not os.environ.get(env):
missing.append(f"ENV: {env}") missing.append(f"ENV: {env}")
return ", ".join(missing) return ", ".join(missing)
def _get_skill_description(self, name: str) -> str: def _get_skill_description(self, name: str) -> str:
"""Get the description of a skill from its frontmatter.""" """Get the description of a skill from its frontmatter."""
meta = self.get_skill_metadata(name) meta = self.get_skill_metadata(name)
if meta and meta.get("description"): if meta and meta.get("description"):
return meta["description"] return meta["description"]
return name # Fallback to skill name return name # Fallback to skill name
def _strip_frontmatter(self, content: str) -> str: def _strip_frontmatter(self, content: str) -> str:
"""Remove YAML frontmatter from markdown content.""" """Remove YAML frontmatter from markdown content."""
if content.startswith("---"): if content.startswith("---"):
@@ -165,7 +165,7 @@ class SkillsLoader:
if match: if match:
return content[match.end():].strip() return content[match.end():].strip()
return content return content
def _parse_nanobot_metadata(self, raw: str) -> dict: def _parse_nanobot_metadata(self, raw: str) -> dict:
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys).""" """Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
try: try:
@@ -173,7 +173,7 @@ class SkillsLoader:
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {} return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}
def _check_requirements(self, skill_meta: dict) -> bool: def _check_requirements(self, skill_meta: dict) -> bool:
"""Check if skill requirements are met (bins, env vars).""" """Check if skill requirements are met (bins, env vars)."""
requires = skill_meta.get("requires", {}) requires = skill_meta.get("requires", {})
@@ -184,12 +184,12 @@ class SkillsLoader:
if not os.environ.get(env): if not os.environ.get(env):
return False return False
return True return True
def _get_skill_meta(self, name: str) -> dict: def _get_skill_meta(self, name: str) -> dict:
"""Get nanobot metadata for a skill (cached in frontmatter).""" """Get nanobot metadata for a skill (cached in frontmatter)."""
meta = self.get_skill_metadata(name) or {} meta = self.get_skill_metadata(name) or {}
return self._parse_nanobot_metadata(meta.get("metadata", "")) return self._parse_nanobot_metadata(meta.get("metadata", ""))
def get_always_skills(self) -> list[str]: def get_always_skills(self) -> list[str]:
"""Get skills marked as always=true that meet requirements.""" """Get skills marked as always=true that meet requirements."""
result = [] result = []
@@ -199,21 +199,21 @@ class SkillsLoader:
if skill_meta.get("always") or meta.get("always"): if skill_meta.get("always") or meta.get("always"):
result.append(s["name"]) result.append(s["name"])
return result return result
def get_skill_metadata(self, name: str) -> dict | None: def get_skill_metadata(self, name: str) -> dict | None:
""" """
Get metadata from a skill's frontmatter. Get metadata from a skill's frontmatter.
Args: Args:
name: Skill name. name: Skill name.
Returns: Returns:
Metadata dict or None. Metadata dict or None.
""" """
content = self.load_skill(name) content = self.load_skill(name)
if not content: if not content:
return None return None
if content.startswith("---"): if content.startswith("---"):
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if match: if match:
@@ -224,5 +224,5 @@ class SkillsLoader:
key, value = line.split(":", 1) key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'') metadata[key.strip()] = value.strip().strip('"\'')
return metadata return metadata
return None return None

View File

@@ -8,18 +8,19 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
class SubagentManager: class SubagentManager:
"""Manages background subagent execution.""" """Manages background subagent execution."""
def __init__( def __init__(
self, self,
provider: LLMProvider, provider: LLMProvider,
@@ -44,7 +45,7 @@ class SubagentManager:
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self._running_tasks: dict[str, asyncio.Task[None]] = {} self._running_tasks: dict[str, asyncio.Task[None]] = {}
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...} self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
async def spawn( async def spawn(
self, self,
task: str, task: str,
@@ -73,10 +74,10 @@ class SubagentManager:
del self._session_tasks[session_key] del self._session_tasks[session_key]
bg_task.add_done_callback(_cleanup) bg_task.add_done_callback(_cleanup)
logger.info("Spawned subagent [{}]: {}", task_id, display_label) logger.info("Spawned subagent [{}]: {}", task_id, display_label)
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes." return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
async def _run_subagent( async def _run_subagent(
self, self,
task_id: str, task_id: str,
@@ -86,7 +87,7 @@ class SubagentManager:
) -> None: ) -> None:
"""Execute the subagent task and announce the result.""" """Execute the subagent task and announce the result."""
logger.info("Subagent [{}] starting task: {}", task_id, label) logger.info("Subagent [{}] starting task: {}", task_id, label)
try: try:
# Build subagent tools (no message tool, no spawn tool) # Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry() tools = ToolRegistry()
@@ -103,22 +104,22 @@ class SubagentManager:
)) ))
tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebSearchTool(api_key=self.brave_api_key))
tools.register(WebFetchTool()) tools.register(WebFetchTool())
# Build messages with subagent-specific prompt # Build messages with subagent-specific prompt
system_prompt = self._build_subagent_prompt(task) system_prompt = self._build_subagent_prompt(task)
messages: list[dict[str, Any]] = [ messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": task}, {"role": "user", "content": task},
] ]
# Run agent loop (limited iterations) # Run agent loop (limited iterations)
max_iterations = 15 max_iterations = 15
iteration = 0 iteration = 0
final_result: str | None = None final_result: str | None = None
while iteration < max_iterations: while iteration < max_iterations:
iteration += 1 iteration += 1
response = await self.provider.chat( response = await self.provider.chat(
messages=messages, messages=messages,
tools=tools.get_definitions(), tools=tools.get_definitions(),
@@ -126,7 +127,7 @@ class SubagentManager:
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
if response.has_tool_calls: if response.has_tool_calls:
# Add assistant message with tool calls # Add assistant message with tool calls
tool_call_dicts = [ tool_call_dicts = [
@@ -145,7 +146,7 @@ class SubagentManager:
"content": response.content or "", "content": response.content or "",
"tool_calls": tool_call_dicts, "tool_calls": tool_call_dicts,
}) })
# Execute tools # Execute tools
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False) args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
@@ -160,18 +161,18 @@ class SubagentManager:
else: else:
final_result = response.content final_result = response.content
break break
if final_result is None: if final_result is None:
final_result = "Task completed but no final response was generated." final_result = "Task completed but no final response was generated."
logger.info("Subagent [{}] completed successfully", task_id) logger.info("Subagent [{}] completed successfully", task_id)
await self._announce_result(task_id, label, task, final_result, origin, "ok") await self._announce_result(task_id, label, task, final_result, origin, "ok")
except Exception as e: except Exception as e:
error_msg = f"Error: {str(e)}" error_msg = f"Error: {str(e)}"
logger.error("Subagent [{}] failed: {}", task_id, e) logger.error("Subagent [{}] failed: {}", task_id, e)
await self._announce_result(task_id, label, task, error_msg, origin, "error") await self._announce_result(task_id, label, task, error_msg, origin, "error")
async def _announce_result( async def _announce_result(
self, self,
task_id: str, task_id: str,
@@ -183,7 +184,7 @@ class SubagentManager:
) -> None: ) -> None:
"""Announce the subagent result to the main agent via the message bus.""" """Announce the subagent result to the main agent via the message bus."""
status_text = "completed successfully" if status == "ok" else "failed" status_text = "completed successfully" if status == "ok" else "failed"
announce_content = f"""[Subagent '{label}' {status_text}] announce_content = f"""[Subagent '{label}' {status_text}]
Task: {task} Task: {task}
@@ -192,7 +193,7 @@ Result:
{result} {result}
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs.""" Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
# Inject as system message to trigger main agent # Inject as system message to trigger main agent
msg = InboundMessage( msg = InboundMessage(
channel="system", channel="system",
@@ -200,14 +201,14 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
chat_id=f"{origin['channel']}:{origin['chat_id']}", chat_id=f"{origin['channel']}:{origin['chat_id']}",
content=announce_content, content=announce_content,
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id']) logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
def _build_subagent_prompt(self, task: str) -> str: def _build_subagent_prompt(self, task: str) -> str:
"""Build a focused system prompt for the subagent.""" """Build a focused system prompt for the subagent."""
from datetime import datetime
import time as _time import time as _time
from datetime import datetime
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
tz = _time.strftime("%Z") or "UTC" tz = _time.strftime("%Z") or "UTC"
@@ -240,7 +241,7 @@ Your workspace is at: {self.workspace}
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed) Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
When you have completed the task, provide a clear summary of your findings or actions.""" When you have completed the task, provide a clear summary of your findings or actions."""
async def cancel_by_session(self, session_key: str) -> int: async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents for the given session. Returns count cancelled.""" """Cancel all subagents for the given session. Returns count cancelled."""
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, []) tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])

View File

@@ -7,11 +7,11 @@ from typing import Any
class Tool(ABC): class Tool(ABC):
""" """
Abstract base class for agent tools. Abstract base class for agent tools.
Tools are capabilities that the agent can use to interact with Tools are capabilities that the agent can use to interact with
the environment, such as reading files, executing commands, etc. the environment, such as reading files, executing commands, etc.
""" """
_TYPE_MAP = { _TYPE_MAP = {
"string": str, "string": str,
"integer": int, "integer": int,
@@ -20,33 +20,33 @@ class Tool(ABC):
"array": list, "array": list,
"object": dict, "object": dict,
} }
@property @property
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
"""Tool name used in function calls.""" """Tool name used in function calls."""
pass pass
@property @property
@abstractmethod @abstractmethod
def description(self) -> str: def description(self) -> str:
"""Description of what the tool does.""" """Description of what the tool does."""
pass pass
@property @property
@abstractmethod @abstractmethod
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
"""JSON Schema for tool parameters.""" """JSON Schema for tool parameters."""
pass pass
@abstractmethod @abstractmethod
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
""" """
Execute the tool with given parameters. Execute the tool with given parameters.
Args: Args:
**kwargs: Tool-specific parameters. **kwargs: Tool-specific parameters.
Returns: Returns:
String result of the tool execution. String result of the tool execution.
""" """
@@ -63,7 +63,7 @@ class Tool(ABC):
t, label = schema.get("type"), path or "parameter" t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"] return [f"{label} should be {t}"]
errors = [] errors = []
if "enum" in schema and val not in schema["enum"]: if "enum" in schema and val not in schema["enum"]:
errors.append(f"{label} must be one of {schema['enum']}") errors.append(f"{label} must be one of {schema['enum']}")
@@ -84,12 +84,14 @@ class Tool(ABC):
errors.append(f"missing required {path + '.' + k if path else k}") errors.append(f"missing required {path + '.' + k if path else k}")
for k, v in val.items(): for k, v in val.items():
if k in props: if k in props:
errors.extend(self._validate(v, props[k], path + '.' + k if path else k)) errors.extend(self._validate(v, props[k], path + "." + k if path else k))
if t == "array" and "items" in schema: if t == "array" and "items" in schema:
for i, item in enumerate(val): for i, item in enumerate(val):
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")) errors.extend(
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
)
return errors return errors
def to_schema(self) -> dict[str, Any]: def to_schema(self) -> dict[str, Any]:
"""Convert tool to OpenAI function schema format.""" """Convert tool to OpenAI function schema format."""
return { return {
@@ -98,5 +100,5 @@ class Tool(ABC):
"name": self.name, "name": self.name,
"description": self.description, "description": self.description,
"parameters": self.parameters, "parameters": self.parameters,
} },
} }

View File

@@ -9,25 +9,25 @@ from nanobot.cron.types import CronSchedule
class CronTool(Tool): class CronTool(Tool):
"""Tool to schedule reminders and recurring tasks.""" """Tool to schedule reminders and recurring tasks."""
def __init__(self, cron_service: CronService): def __init__(self, cron_service: CronService):
self._cron = cron_service self._cron = cron_service
self._channel = "" self._channel = ""
self._chat_id = "" self._chat_id = ""
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel = channel self._channel = channel
self._chat_id = chat_id self._chat_id = chat_id
@property @property
def name(self) -> str: def name(self) -> str:
return "cron" return "cron"
@property @property
def description(self) -> str: def description(self) -> str:
return "Schedule reminders and recurring tasks. Actions: add, list, remove." return "Schedule reminders and recurring tasks. Actions: add, list, remove."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@@ -36,36 +36,30 @@ class CronTool(Tool):
"action": { "action": {
"type": "string", "type": "string",
"enum": ["add", "list", "remove"], "enum": ["add", "list", "remove"],
"description": "Action to perform" "description": "Action to perform",
},
"message": {
"type": "string",
"description": "Reminder message (for add)"
}, },
"message": {"type": "string", "description": "Reminder message (for add)"},
"every_seconds": { "every_seconds": {
"type": "integer", "type": "integer",
"description": "Interval in seconds (for recurring tasks)" "description": "Interval in seconds (for recurring tasks)",
}, },
"cron_expr": { "cron_expr": {
"type": "string", "type": "string",
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)" "description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
}, },
"tz": { "tz": {
"type": "string", "type": "string",
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')" "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
}, },
"at": { "at": {
"type": "string", "type": "string",
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')" "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
}, },
"job_id": { "job_id": {"type": "string", "description": "Job ID (for remove)"},
"type": "string",
"description": "Job ID (for remove)"
}
}, },
"required": ["action"] "required": ["action"],
} }
async def execute( async def execute(
self, self,
action: str, action: str,
@@ -75,7 +69,7 @@ class CronTool(Tool):
tz: str | None = None, tz: str | None = None,
at: str | None = None, at: str | None = None,
job_id: str | None = None, job_id: str | None = None,
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
if action == "add": if action == "add":
return self._add_job(message, every_seconds, cron_expr, tz, at) return self._add_job(message, every_seconds, cron_expr, tz, at)
@@ -84,7 +78,7 @@ class CronTool(Tool):
elif action == "remove": elif action == "remove":
return self._remove_job(job_id) return self._remove_job(job_id)
return f"Unknown action: {action}" return f"Unknown action: {action}"
def _add_job( def _add_job(
self, self,
message: str, message: str,
@@ -101,11 +95,12 @@ class CronTool(Tool):
return "Error: tz can only be used with cron_expr" return "Error: tz can only be used with cron_expr"
if tz: if tz:
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
try: try:
ZoneInfo(tz) ZoneInfo(tz)
except (KeyError, Exception): except (KeyError, Exception):
return f"Error: unknown timezone '{tz}'" return f"Error: unknown timezone '{tz}'"
# Build schedule # Build schedule
delete_after = False delete_after = False
if every_seconds: if every_seconds:
@@ -114,13 +109,14 @@ class CronTool(Tool):
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz) schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at: elif at:
from datetime import datetime from datetime import datetime
dt = datetime.fromisoformat(at) dt = datetime.fromisoformat(at)
at_ms = int(dt.timestamp() * 1000) at_ms = int(dt.timestamp() * 1000)
schedule = CronSchedule(kind="at", at_ms=at_ms) schedule = CronSchedule(kind="at", at_ms=at_ms)
delete_after = True delete_after = True
else: else:
return "Error: either every_seconds, cron_expr, or at is required" return "Error: either every_seconds, cron_expr, or at is required"
job = self._cron.add_job( job = self._cron.add_job(
name=message[:30], name=message[:30],
schedule=schedule, schedule=schedule,
@@ -131,14 +127,14 @@ class CronTool(Tool):
delete_after_run=delete_after, delete_after_run=delete_after,
) )
return f"Created job '{job.name}' (id: {job.id})" return f"Created job '{job.name}' (id: {job.id})"
def _list_jobs(self) -> str: def _list_jobs(self) -> str:
jobs = self._cron.list_jobs() jobs = self._cron.list_jobs()
if not jobs: if not jobs:
return "No scheduled jobs." return "No scheduled jobs."
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs] lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
return "Scheduled jobs:\n" + "\n".join(lines) return "Scheduled jobs:\n" + "\n".join(lines)
def _remove_job(self, job_id: str | None) -> str: def _remove_job(self, job_id: str | None) -> str:
if not job_id: if not job_id:
return "Error: job_id is required for remove" return "Error: job_id is required for remove"

View File

@@ -7,7 +7,9 @@ from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path: def _resolve_path(
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
) -> Path:
"""Resolve path against workspace (if relative) and enforce directory restriction.""" """Resolve path against workspace (if relative) and enforce directory restriction."""
p = Path(path).expanduser() p = Path(path).expanduser()
if not p.is_absolute() and workspace: if not p.is_absolute() and workspace:
@@ -31,24 +33,19 @@ class ReadFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "read_file" return "read_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Read the contents of a file at the given path." return "Read the contents of a file at the given path."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The file path to read"}},
"path": { "required": ["path"],
"type": "string",
"description": "The file path to read"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -75,28 +72,22 @@ class WriteFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "write_file" return "write_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Write content to a file at the given path. Creates parent directories if needed." return "Write content to a file at the given path. Creates parent directories if needed."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to write to"},
"type": "string", "content": {"type": "string", "description": "The content to write"},
"description": "The file path to write to"
},
"content": {
"type": "string",
"description": "The content to write"
}
}, },
"required": ["path", "content"] "required": ["path", "content"],
} }
async def execute(self, path: str, content: str, **kwargs: Any) -> str: async def execute(self, path: str, content: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -119,32 +110,23 @@ class EditFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "edit_file" return "edit_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to edit"},
"type": "string", "old_text": {"type": "string", "description": "The exact text to find and replace"},
"description": "The file path to edit" "new_text": {"type": "string", "description": "The text to replace with"},
},
"old_text": {
"type": "string",
"description": "The exact text to find and replace"
},
"new_text": {
"type": "string",
"description": "The text to replace with"
}
}, },
"required": ["path", "old_text", "new_text"] "required": ["path", "old_text", "new_text"],
} }
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -184,13 +166,19 @@ class EditFileTool(Tool):
best_ratio, best_start = ratio, i best_ratio, best_start = ratio, i
if best_ratio > 0.5: if best_ratio > 0.5:
diff = "\n".join(difflib.unified_diff( diff = "\n".join(
old_lines, lines[best_start : best_start + window], difflib.unified_diff(
fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})", old_lines,
lineterm="", lines[best_start : best_start + window],
)) fromfile="old_text (provided)",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
)
)
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
return f"Error: old_text not found in {path}. No similar text found. Verify the file content." return (
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
)
class ListDirTool(Tool): class ListDirTool(Tool):
@@ -203,24 +191,19 @@ class ListDirTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "list_dir" return "list_dir"
@property @property
def description(self) -> str: def description(self) -> str:
return "List the contents of a directory." return "List the contents of a directory."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The directory path to list"}},
"path": { "required": ["path"],
"type": "string",
"description": "The directory path to list"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
dir_path = _resolve_path(path, self._workspace, self._allowed_dir) dir_path = _resolve_path(path, self._workspace, self._allowed_dir)

View File

@@ -8,33 +8,33 @@ from nanobot.agent.tools.base import Tool
class ToolRegistry: class ToolRegistry:
""" """
Registry for agent tools. Registry for agent tools.
Allows dynamic registration and execution of tools. Allows dynamic registration and execution of tools.
""" """
def __init__(self): def __init__(self):
self._tools: dict[str, Tool] = {} self._tools: dict[str, Tool] = {}
def register(self, tool: Tool) -> None: def register(self, tool: Tool) -> None:
"""Register a tool.""" """Register a tool."""
self._tools[tool.name] = tool self._tools[tool.name] = tool
def unregister(self, name: str) -> None: def unregister(self, name: str) -> None:
"""Unregister a tool by name.""" """Unregister a tool by name."""
self._tools.pop(name, None) self._tools.pop(name, None)
def get(self, name: str) -> Tool | None: def get(self, name: str) -> Tool | None:
"""Get a tool by name.""" """Get a tool by name."""
return self._tools.get(name) return self._tools.get(name)
def has(self, name: str) -> bool: def has(self, name: str) -> bool:
"""Check if a tool is registered.""" """Check if a tool is registered."""
return name in self._tools return name in self._tools
def get_definitions(self) -> list[dict[str, Any]]: def get_definitions(self) -> list[dict[str, Any]]:
"""Get all tool definitions in OpenAI format.""" """Get all tool definitions in OpenAI format."""
return [tool.to_schema() for tool in self._tools.values()] return [tool.to_schema() for tool in self._tools.values()]
async def execute(self, name: str, params: dict[str, Any]) -> str: async def execute(self, name: str, params: dict[str, Any]) -> str:
"""Execute a tool by name with given parameters.""" """Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]" _HINT = "\n\n[Analyze the error above and try a different approach.]"
@@ -53,14 +53,14 @@ class ToolRegistry:
return result return result
except Exception as e: except Exception as e:
return f"Error executing {name}: {str(e)}" + _HINT return f"Error executing {name}: {str(e)}" + _HINT
@property @property
def tool_names(self) -> list[str]: def tool_names(self) -> list[str]:
"""Get list of registered tool names.""" """Get list of registered tool names."""
return list(self._tools.keys()) return list(self._tools.keys())
def __len__(self) -> int: def __len__(self) -> int:
return len(self._tools) return len(self._tools)
def __contains__(self, name: str) -> bool: def __contains__(self, name: str) -> bool:
return name in self._tools return name in self._tools

View File

@@ -11,7 +11,7 @@ from nanobot.agent.tools.base import Tool
class ExecTool(Tool): class ExecTool(Tool):
"""Tool to execute shell commands.""" """Tool to execute shell commands."""
def __init__( def __init__(
self, self,
timeout: int = 60, timeout: int = 60,
@@ -37,15 +37,15 @@ class ExecTool(Tool):
self.allow_patterns = allow_patterns or [] self.allow_patterns = allow_patterns or []
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append self.path_append = path_append
@property @property
def name(self) -> str: def name(self) -> str:
return "exec" return "exec"
@property @property
def description(self) -> str: def description(self) -> str:
return "Execute a shell command and return its output. Use with caution." return "Execute a shell command and return its output. Use with caution."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {

View File

@@ -1,6 +1,6 @@
"""Spawn tool for creating background subagents.""" """Spawn tool for creating background subagents."""
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -10,23 +10,23 @@ if TYPE_CHECKING:
class SpawnTool(Tool): class SpawnTool(Tool):
"""Tool to spawn a subagent for background task execution.""" """Tool to spawn a subagent for background task execution."""
def __init__(self, manager: "SubagentManager"): def __init__(self, manager: "SubagentManager"):
self._manager = manager self._manager = manager
self._origin_channel = "cli" self._origin_channel = "cli"
self._origin_chat_id = "direct" self._origin_chat_id = "direct"
self._session_key = "cli:direct" self._session_key = "cli:direct"
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the origin context for subagent announcements.""" """Set the origin context for subagent announcements."""
self._origin_channel = channel self._origin_channel = channel
self._origin_chat_id = chat_id self._origin_chat_id = chat_id
self._session_key = f"{channel}:{chat_id}" self._session_key = f"{channel}:{chat_id}"
@property @property
def name(self) -> str: def name(self) -> str:
return "spawn" return "spawn"
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
@@ -34,7 +34,7 @@ class SpawnTool(Tool):
"Use this for complex or time-consuming tasks that can run independently. " "Use this for complex or time-consuming tasks that can run independently. "
"The subagent will complete the task and report back when done." "The subagent will complete the task and report back when done."
) )
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@@ -51,7 +51,7 @@ class SpawnTool(Tool):
}, },
"required": ["task"], "required": ["task"],
} }
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str: async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
"""Spawn a subagent to execute the given task.""" """Spawn a subagent to execute the given task."""
return await self._manager.spawn( return await self._manager.spawn(

View File

@@ -45,7 +45,7 @@ def _validate_url(url: str) -> tuple[bool, str]:
class WebSearchTool(Tool): class WebSearchTool(Tool):
"""Search the web using Brave Search API.""" """Search the web using Brave Search API."""
name = "web_search" name = "web_search"
description = "Search the web. Returns titles, URLs, and snippets." description = "Search the web. Returns titles, URLs, and snippets."
parameters = { parameters = {
@@ -56,7 +56,7 @@ class WebSearchTool(Tool):
}, },
"required": ["query"] "required": ["query"]
} }
def __init__(self, api_key: str | None = None, max_results: int = 5): def __init__(self, api_key: str | None = None, max_results: int = 5):
self._init_api_key = api_key self._init_api_key = api_key
self.max_results = max_results self.max_results = max_results
@@ -73,7 +73,7 @@ class WebSearchTool(Tool):
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey " "Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
"(or export BRAVE_API_KEY), then restart the gateway." "(or export BRAVE_API_KEY), then restart the gateway."
) )
try: try:
n = min(max(count or self.max_results, 1), 10) n = min(max(count or self.max_results, 1), 10)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -84,11 +84,11 @@ class WebSearchTool(Tool):
timeout=10.0 timeout=10.0
) )
r.raise_for_status() r.raise_for_status()
results = r.json().get("web", {}).get("results", []) results = r.json().get("web", {}).get("results", [])
if not results: if not results:
return f"No results for: {query}" return f"No results for: {query}"
lines = [f"Results for: {query}\n"] lines = [f"Results for: {query}\n"]
for i, item in enumerate(results[:n], 1): for i, item in enumerate(results[:n], 1):
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}") lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
@@ -101,7 +101,7 @@ class WebSearchTool(Tool):
class WebFetchTool(Tool): class WebFetchTool(Tool):
"""Fetch and extract content from a URL using Readability.""" """Fetch and extract content from a URL using Readability."""
name = "web_fetch" name = "web_fetch"
description = "Fetch URL and extract readable content (HTML → markdown/text)." description = "Fetch URL and extract readable content (HTML → markdown/text)."
parameters = { parameters = {
@@ -113,10 +113,10 @@ class WebFetchTool(Tool):
}, },
"required": ["url"] "required": ["url"]
} }
def __init__(self, max_chars: int = 50000): def __init__(self, max_chars: int = 50000):
self.max_chars = max_chars self.max_chars = max_chars
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str: async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
from readability import Document from readability import Document
@@ -135,9 +135,9 @@ class WebFetchTool(Tool):
) as client: ) as client:
r = await client.get(url, headers={"User-Agent": USER_AGENT}) r = await client.get(url, headers={"User-Agent": USER_AGENT})
r.raise_for_status() r.raise_for_status()
ctype = r.headers.get("content-type", "") ctype = r.headers.get("content-type", "")
# JSON # JSON
if "application/json" in ctype: if "application/json" in ctype:
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json" text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
@@ -149,16 +149,16 @@ class WebFetchTool(Tool):
extractor = "readability" extractor = "readability"
else: else:
text, extractor = r.text, "raw" text, extractor = r.text, "raw"
truncated = len(text) > max_chars truncated = len(text) > max_chars
if truncated: if truncated:
text = text[:max_chars] text = text[:max_chars]
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code, return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False) "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
except Exception as e: except Exception as e:
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False) return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
def _to_markdown(self, html: str) -> str: def _to_markdown(self, html: str) -> str:
"""Convert HTML to markdown.""" """Convert HTML to markdown."""
# Convert links, headings, lists before stripping tags # Convert links, headings, lists before stripping tags

View File

@@ -8,7 +8,7 @@ from typing import Any
@dataclass @dataclass
class InboundMessage: class InboundMessage:
"""Message received from a chat channel.""" """Message received from a chat channel."""
channel: str # telegram, discord, slack, whatsapp channel: str # telegram, discord, slack, whatsapp
sender_id: str # User identifier sender_id: str # User identifier
chat_id: str # Chat/channel identifier chat_id: str # Chat/channel identifier
@@ -17,7 +17,7 @@ class InboundMessage:
media: list[str] = field(default_factory=list) # Media URLs media: list[str] = field(default_factory=list) # Media URLs
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
session_key_override: str | None = None # Optional override for thread-scoped sessions session_key_override: str | None = None # Optional override for thread-scoped sessions
@property @property
def session_key(self) -> str: def session_key(self) -> str:
"""Unique key for session identification.""" """Unique key for session identification."""
@@ -27,7 +27,7 @@ class InboundMessage:
@dataclass @dataclass
class OutboundMessage: class OutboundMessage:
"""Message to send to a chat channel.""" """Message to send to a chat channel."""
channel: str channel: str
chat_id: str chat_id: str
content: str content: str

View File

@@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus
class BaseChannel(ABC): class BaseChannel(ABC):
""" """
Abstract base class for chat channel implementations. Abstract base class for chat channel implementations.
Each channel (Telegram, Discord, etc.) should implement this interface Each channel (Telegram, Discord, etc.) should implement this interface
to integrate with the nanobot message bus. to integrate with the nanobot message bus.
""" """
name: str = "base" name: str = "base"
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """
Initialize the channel. Initialize the channel.
Args: Args:
config: Channel-specific configuration. config: Channel-specific configuration.
bus: The message bus for communication. bus: The message bus for communication.
@@ -30,50 +30,50 @@ class BaseChannel(ABC):
self.config = config self.config = config
self.bus = bus self.bus = bus
self._running = False self._running = False
@abstractmethod @abstractmethod
async def start(self) -> None: async def start(self) -> None:
""" """
Start the channel and begin listening for messages. Start the channel and begin listening for messages.
This should be a long-running async task that: This should be a long-running async task that:
1. Connects to the chat platform 1. Connects to the chat platform
2. Listens for incoming messages 2. Listens for incoming messages
3. Forwards messages to the bus via _handle_message() 3. Forwards messages to the bus via _handle_message()
""" """
pass pass
@abstractmethod @abstractmethod
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the channel and clean up resources.""" """Stop the channel and clean up resources."""
pass pass
@abstractmethod @abstractmethod
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
""" """
Send a message through this channel. Send a message through this channel.
Args: Args:
msg: The message to send. msg: The message to send.
""" """
pass pass
def is_allowed(self, sender_id: str) -> bool: def is_allowed(self, sender_id: str) -> bool:
""" """
Check if a sender is allowed to use this bot. Check if a sender is allowed to use this bot.
Args: Args:
sender_id: The sender's identifier. sender_id: The sender's identifier.
Returns: Returns:
True if allowed, False otherwise. True if allowed, False otherwise.
""" """
allow_list = getattr(self.config, "allow_from", []) allow_list = getattr(self.config, "allow_from", [])
# If no allow list, allow everyone # If no allow list, allow everyone
if not allow_list: if not allow_list:
return True return True
sender_str = str(sender_id) sender_str = str(sender_id)
if sender_str in allow_list: if sender_str in allow_list:
return True return True
@@ -82,7 +82,7 @@ class BaseChannel(ABC):
if part and part in allow_list: if part and part in allow_list:
return True return True
return False return False
async def _handle_message( async def _handle_message(
self, self,
sender_id: str, sender_id: str,
@@ -94,9 +94,9 @@ class BaseChannel(ABC):
) -> None: ) -> None:
""" """
Handle an incoming message from the chat platform. Handle an incoming message from the chat platform.
This method checks permissions and forwards to the bus. This method checks permissions and forwards to the bus.
Args: Args:
sender_id: The sender's identifier. sender_id: The sender's identifier.
chat_id: The chat/channel identifier. chat_id: The chat/channel identifier.
@@ -112,7 +112,7 @@ class BaseChannel(ABC):
sender_id, self.name, sender_id, self.name,
) )
return return
msg = InboundMessage( msg = InboundMessage(
channel=self.name, channel=self.name,
sender_id=str(sender_id), sender_id=str(sender_id),
@@ -122,9 +122,9 @@ class BaseChannel(ABC):
metadata=metadata or {}, metadata=metadata or {},
session_key_override=session_key, session_key_override=session_key,
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""Check if the channel is running.""" """Check if the channel is running."""

View File

@@ -5,8 +5,8 @@ import json
import time import time
from typing import Any from typing import Any
from loguru import logger
import httpx import httpx
from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -15,11 +15,11 @@ from nanobot.config.schema import DingTalkConfig
try: try:
from dingtalk_stream import ( from dingtalk_stream import (
DingTalkStreamClient, AckMessage,
Credential,
CallbackHandler, CallbackHandler,
CallbackMessage, CallbackMessage,
AckMessage, Credential,
DingTalkStreamClient,
) )
from dingtalk_stream.chatbot import ChatbotMessage from dingtalk_stream.chatbot import ChatbotMessage

View File

@@ -14,7 +14,6 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import DiscordConfig from nanobot.config.schema import DiscordConfig
DISCORD_API_BASE = "https://discord.com/api/v10" DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit MAX_MESSAGE_LEN = 2000 # Discord message character limit

View File

@@ -23,12 +23,11 @@ try:
CreateFileRequestBody, CreateFileRequestBody,
CreateImageRequest, CreateImageRequest,
CreateImageRequestBody, CreateImageRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
CreateMessageReactionRequest, CreateMessageReactionRequest,
CreateMessageReactionRequestBody, CreateMessageReactionRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
Emoji, Emoji,
GetFileRequest,
GetMessageResourceRequest, GetMessageResourceRequest,
P2ImMessageReceiveV1, P2ImMessageReceiveV1,
) )
@@ -70,7 +69,7 @@ def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
def _extract_interactive_content(content: dict) -> list[str]: def _extract_interactive_content(content: dict) -> list[str]:
"""Recursively extract text and links from interactive card content.""" """Recursively extract text and links from interactive card content."""
parts = [] parts = []
if isinstance(content, str): if isinstance(content, str):
try: try:
content = json.loads(content) content = json.loads(content)
@@ -104,19 +103,19 @@ def _extract_interactive_content(content: dict) -> list[str]:
header_text = header_title.get("content", "") or header_title.get("text", "") header_text = header_title.get("content", "") or header_title.get("text", "")
if header_text: if header_text:
parts.append(f"title: {header_text}") parts.append(f"title: {header_text}")
return parts return parts
def _extract_element_content(element: dict) -> list[str]: def _extract_element_content(element: dict) -> list[str]:
"""Extract content from a single card element.""" """Extract content from a single card element."""
parts = [] parts = []
if not isinstance(element, dict): if not isinstance(element, dict):
return parts return parts
tag = element.get("tag", "") tag = element.get("tag", "")
if tag in ("markdown", "lark_md"): if tag in ("markdown", "lark_md"):
content = element.get("content", "") content = element.get("content", "")
if content: if content:
@@ -177,17 +176,17 @@ def _extract_element_content(element: dict) -> list[str]:
else: else:
for ne in element.get("elements", []): for ne in element.get("elements", []):
parts.extend(_extract_element_content(ne)) parts.extend(_extract_element_content(ne))
return parts return parts
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]: def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
"""Extract text and image keys from Feishu post (rich text) message content. """Extract text and image keys from Feishu post (rich text) message content.
Supports two formats: Supports two formats:
1. Direct format: {"title": "...", "content": [...]} 1. Direct format: {"title": "...", "content": [...]}
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}} 2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
Returns: Returns:
(text, image_keys) - extracted text and list of image keys (text, image_keys) - extracted text and list of image keys
""" """
@@ -220,26 +219,26 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
image_keys.append(img_key) image_keys.append(img_key)
text = " ".join(text_parts).strip() if text_parts else None text = " ".join(text_parts).strip() if text_parts else None
return text, image_keys return text, image_keys
# Try direct format first # Try direct format first
if "content" in content_json: if "content" in content_json:
text, images = extract_from_lang(content_json) text, images = extract_from_lang(content_json)
if text or images: if text or images:
return text or "", images return text or "", images
# Try localized format # Try localized format
for lang_key in ("zh_cn", "en_us", "ja_jp"): for lang_key in ("zh_cn", "en_us", "ja_jp"):
lang_content = content_json.get(lang_key) lang_content = content_json.get(lang_key)
text, images = extract_from_lang(lang_content) text, images = extract_from_lang(lang_content)
if text or images: if text or images:
return text or "", images return text or "", images
return "", [] return "", []
def _extract_post_text(content_json: dict) -> str: def _extract_post_text(content_json: dict) -> str:
"""Extract plain text from Feishu post (rich text) message content. """Extract plain text from Feishu post (rich text) message content.
Legacy wrapper for _extract_post_content, returns only text. Legacy wrapper for _extract_post_content, returns only text.
""" """
text, _ = _extract_post_content(content_json) text, _ = _extract_post_content(content_json)
@@ -249,17 +248,17 @@ def _extract_post_text(content_json: dict) -> str:
class FeishuChannel(BaseChannel): class FeishuChannel(BaseChannel):
""" """
Feishu/Lark channel using WebSocket long connection. Feishu/Lark channel using WebSocket long connection.
Uses WebSocket to receive events - no public IP or webhook required. Uses WebSocket to receive events - no public IP or webhook required.
Requires: Requires:
- App ID and App Secret from Feishu Open Platform - App ID and App Secret from Feishu Open Platform
- Bot capability enabled - Bot capability enabled
- Event subscription enabled (im.message.receive_v1) - Event subscription enabled (im.message.receive_v1)
""" """
name = "feishu" name = "feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus): def __init__(self, config: FeishuConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: FeishuConfig = config self.config: FeishuConfig = config
@@ -268,27 +267,27 @@ class FeishuChannel(BaseChannel):
self._ws_thread: threading.Thread | None = None self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
async def start(self) -> None: async def start(self) -> None:
"""Start the Feishu bot with WebSocket long connection.""" """Start the Feishu bot with WebSocket long connection."""
if not FEISHU_AVAILABLE: if not FEISHU_AVAILABLE:
logger.error("Feishu SDK not installed. Run: pip install lark-oapi") logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
return return
if not self.config.app_id or not self.config.app_secret: if not self.config.app_id or not self.config.app_secret:
logger.error("Feishu app_id and app_secret not configured") logger.error("Feishu app_id and app_secret not configured")
return return
self._running = True self._running = True
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
# Create Lark client for sending messages # Create Lark client for sending messages
self._client = lark.Client.builder() \ self._client = lark.Client.builder() \
.app_id(self.config.app_id) \ .app_id(self.config.app_id) \
.app_secret(self.config.app_secret) \ .app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \ .log_level(lark.LogLevel.INFO) \
.build() .build()
# Create event handler (only register message receive, ignore other events) # Create event handler (only register message receive, ignore other events)
event_handler = lark.EventDispatcherHandler.builder( event_handler = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "", self.config.encrypt_key or "",
@@ -296,7 +295,7 @@ class FeishuChannel(BaseChannel):
).register_p2_im_message_receive_v1( ).register_p2_im_message_receive_v1(
self._on_message_sync self._on_message_sync
).build() ).build()
# Create WebSocket client for long connection # Create WebSocket client for long connection
self._ws_client = lark.ws.Client( self._ws_client = lark.ws.Client(
self.config.app_id, self.config.app_id,
@@ -304,7 +303,7 @@ class FeishuChannel(BaseChannel):
event_handler=event_handler, event_handler=event_handler,
log_level=lark.LogLevel.INFO log_level=lark.LogLevel.INFO
) )
# Start WebSocket client in a separate thread with reconnect loop # Start WebSocket client in a separate thread with reconnect loop
def run_ws(): def run_ws():
while self._running: while self._running:
@@ -313,18 +312,19 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.warning("Feishu WebSocket error: {}", e) logger.warning("Feishu WebSocket error: {}", e)
if self._running: if self._running:
import time; time.sleep(5) import time
time.sleep(5)
self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread = threading.Thread(target=run_ws, daemon=True)
self._ws_thread.start() self._ws_thread.start()
logger.info("Feishu bot started with WebSocket long connection") logger.info("Feishu bot started with WebSocket long connection")
logger.info("No public IP required - using WebSocket to receive events") logger.info("No public IP required - using WebSocket to receive events")
# Keep running until stopped # Keep running until stopped
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the Feishu bot.""" """Stop the Feishu bot."""
self._running = False self._running = False
@@ -334,7 +334,7 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.warning("Error stopping WebSocket client: {}", e) logger.warning("Error stopping WebSocket client: {}", e)
logger.info("Feishu bot stopped") logger.info("Feishu bot stopped")
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool).""" """Sync helper for adding reaction (runs in thread pool)."""
try: try:
@@ -345,9 +345,9 @@ class FeishuChannel(BaseChannel):
.reaction_type(Emoji.builder().emoji_type(emoji_type).build()) .reaction_type(Emoji.builder().emoji_type(emoji_type).build())
.build() .build()
).build() ).build()
response = self._client.im.v1.message_reaction.create(request) response = self._client.im.v1.message_reaction.create(request)
if not response.success(): if not response.success():
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg) logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
else: else:
@@ -358,15 +358,15 @@ class FeishuChannel(BaseChannel):
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None: async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
""" """
Add a reaction emoji to a message (non-blocking). Add a reaction emoji to a message (non-blocking).
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
""" """
if not self._client or not Emoji: if not self._client or not Emoji:
return return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type) await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
# Regex to match markdown tables (header + separator + data rows) # Regex to match markdown tables (header + separator + data rows)
_TABLE_RE = re.compile( _TABLE_RE = re.compile(
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)", r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
@@ -380,12 +380,13 @@ class FeishuChannel(BaseChannel):
@staticmethod @staticmethod
def _parse_md_table(table_text: str) -> dict | None: def _parse_md_table(table_text: str) -> dict | None:
"""Parse a markdown table into a Feishu table element.""" """Parse a markdown table into a Feishu table element."""
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()] lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
if len(lines) < 3: if len(lines) < 3:
return None return None
split = lambda l: [c.strip() for c in l.strip("|").split("|")] def split(_line: str) -> list[str]:
return [c.strip() for c in _line.strip("|").split("|")]
headers = split(lines[0]) headers = split(lines[0])
rows = [split(l) for l in lines[2:]] rows = [split(_line) for _line in lines[2:]]
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"} columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)] for i, h in enumerate(headers)]
return { return {
@@ -657,7 +658,7 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.error("Error sending Feishu message: {}", e) logger.error("Error sending Feishu message: {}", e)
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None: def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
""" """
Sync handler for incoming messages (called from WebSocket thread). Sync handler for incoming messages (called from WebSocket thread).
@@ -665,7 +666,7 @@ class FeishuChannel(BaseChannel):
""" """
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop) asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
"""Handle incoming message from Feishu.""" """Handle incoming message from Feishu."""
try: try:

View File

@@ -16,24 +16,24 @@ from nanobot.config.schema import Config
class ChannelManager: class ChannelManager:
""" """
Manages chat channels and coordinates message routing. Manages chat channels and coordinates message routing.
Responsibilities: Responsibilities:
- Initialize enabled channels (Telegram, WhatsApp, etc.) - Initialize enabled channels (Telegram, WhatsApp, etc.)
- Start/stop channels - Start/stop channels
- Route outbound messages - Route outbound messages
""" """
def __init__(self, config: Config, bus: MessageBus): def __init__(self, config: Config, bus: MessageBus):
self.config = config self.config = config
self.bus = bus self.bus = bus
self.channels: dict[str, BaseChannel] = {} self.channels: dict[str, BaseChannel] = {}
self._dispatch_task: asyncio.Task | None = None self._dispatch_task: asyncio.Task | None = None
self._init_channels() self._init_channels()
def _init_channels(self) -> None: def _init_channels(self) -> None:
"""Initialize channels based on config.""" """Initialize channels based on config."""
# Telegram channel # Telegram channel
if self.config.channels.telegram.enabled: if self.config.channels.telegram.enabled:
try: try:
@@ -46,7 +46,7 @@ class ChannelManager:
logger.info("Telegram channel enabled") logger.info("Telegram channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Telegram channel not available: {}", e) logger.warning("Telegram channel not available: {}", e)
# WhatsApp channel # WhatsApp channel
if self.config.channels.whatsapp.enabled: if self.config.channels.whatsapp.enabled:
try: try:
@@ -68,7 +68,7 @@ class ChannelManager:
logger.info("Discord channel enabled") logger.info("Discord channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Discord channel not available: {}", e) logger.warning("Discord channel not available: {}", e)
# Feishu channel # Feishu channel
if self.config.channels.feishu.enabled: if self.config.channels.feishu.enabled:
try: try:
@@ -136,7 +136,7 @@ class ChannelManager:
logger.info("QQ channel enabled") logger.info("QQ channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("QQ channel not available: {}", e) logger.warning("QQ channel not available: {}", e)
# Matrix channel # Matrix channel
if self.config.channels.matrix.enabled: if self.config.channels.matrix.enabled:
try: try:
@@ -148,7 +148,7 @@ class ChannelManager:
logger.info("Matrix channel enabled") logger.info("Matrix channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Matrix channel not available: {}", e) logger.warning("Matrix channel not available: {}", e)
async def _start_channel(self, name: str, channel: BaseChannel) -> None: async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions.""" """Start a channel and log any exceptions."""
try: try:
@@ -161,23 +161,23 @@ class ChannelManager:
if not self.channels: if not self.channels:
logger.warning("No channels enabled") logger.warning("No channels enabled")
return return
# Start outbound dispatcher # Start outbound dispatcher
self._dispatch_task = asyncio.create_task(self._dispatch_outbound()) self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
# Start channels # Start channels
tasks = [] tasks = []
for name, channel in self.channels.items(): for name, channel in self.channels.items():
logger.info("Starting {} channel...", name) logger.info("Starting {} channel...", name)
tasks.append(asyncio.create_task(self._start_channel(name, channel))) tasks.append(asyncio.create_task(self._start_channel(name, channel)))
# Wait for all to complete (they should run forever) # Wait for all to complete (they should run forever)
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
async def stop_all(self) -> None: async def stop_all(self) -> None:
"""Stop all channels and the dispatcher.""" """Stop all channels and the dispatcher."""
logger.info("Stopping all channels...") logger.info("Stopping all channels...")
# Stop dispatcher # Stop dispatcher
if self._dispatch_task: if self._dispatch_task:
self._dispatch_task.cancel() self._dispatch_task.cancel()
@@ -185,7 +185,7 @@ class ChannelManager:
await self._dispatch_task await self._dispatch_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Stop all channels # Stop all channels
for name, channel in self.channels.items(): for name, channel in self.channels.items():
try: try:
@@ -193,24 +193,24 @@ class ChannelManager:
logger.info("Stopped {} channel", name) logger.info("Stopped {} channel", name)
except Exception as e: except Exception as e:
logger.error("Error stopping {}: {}", name, e) logger.error("Error stopping {}: {}", name, e)
async def _dispatch_outbound(self) -> None: async def _dispatch_outbound(self) -> None:
"""Dispatch outbound messages to the appropriate channel.""" """Dispatch outbound messages to the appropriate channel."""
logger.info("Outbound dispatcher started") logger.info("Outbound dispatcher started")
while True: while True:
try: try:
msg = await asyncio.wait_for( msg = await asyncio.wait_for(
self.bus.consume_outbound(), self.bus.consume_outbound(),
timeout=1.0 timeout=1.0
) )
if msg.metadata.get("_progress"): if msg.metadata.get("_progress"):
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints: if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
continue continue
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress: if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
continue continue
channel = self.channels.get(msg.channel) channel = self.channels.get(msg.channel)
if channel: if channel:
try: try:
@@ -219,16 +219,16 @@ class ChannelManager:
logger.error("Error sending to {}: {}", msg.channel, e) logger.error("Error sending to {}: {}", msg.channel, e)
else: else:
logger.warning("Unknown channel: {}", msg.channel) logger.warning("Unknown channel: {}", msg.channel)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except asyncio.CancelledError: except asyncio.CancelledError:
break break
def get_channel(self, name: str) -> BaseChannel | None: def get_channel(self, name: str) -> BaseChannel | None:
"""Get a channel by name.""" """Get a channel by name."""
return self.channels.get(name) return self.channels.get(name)
def get_status(self) -> dict[str, Any]: def get_status(self) -> dict[str, Any]:
"""Get status of all channels.""" """Get status of all channels."""
return { return {
@@ -238,7 +238,7 @@ class ChannelManager:
} }
for name, channel in self.channels.items() for name, channel in self.channels.items()
} }
@property @property
def enabled_channels(self) -> list[str]: def enabled_channels(self) -> list[str]:
"""Get list of enabled channel names.""" """Get list of enabled channel names."""

View File

@@ -12,10 +12,22 @@ try:
import nh3 import nh3
from mistune import create_markdown from mistune import create_markdown
from nio import ( from nio import (
AsyncClient, AsyncClientConfig, ContentRepositoryConfigError, AsyncClient,
DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse, AsyncClientConfig,
RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText, ContentRepositoryConfigError,
RoomSendError, RoomTypingError, SyncError, UploadError, DownloadError,
InviteEvent,
JoinError,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
RoomMessage,
RoomMessageMedia,
RoomMessageText,
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
) )
from nio.crypto.attachments import decrypt_attachment from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError from nio.exceptions import EncryptionError

View File

@@ -5,11 +5,10 @@ import re
from typing import Any from typing import Any
from loguru import logger from loguru import logger
from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
from slackify_markdown import slackify_markdown from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
from loguru import logger from loguru import logger
from telegram import BotCommand, Update, ReplyParameters from telegram import BotCommand, ReplyParameters, Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
@@ -21,60 +22,60 @@ def _markdown_to_telegram_html(text: str) -> str:
""" """
if not text: if not text:
return "" return ""
# 1. Extract and protect code blocks (preserve content from other processing) # 1. Extract and protect code blocks (preserve content from other processing)
code_blocks: list[str] = [] code_blocks: list[str] = []
def save_code_block(m: re.Match) -> str: def save_code_block(m: re.Match) -> str:
code_blocks.append(m.group(1)) code_blocks.append(m.group(1))
return f"\x00CB{len(code_blocks) - 1}\x00" return f"\x00CB{len(code_blocks) - 1}\x00"
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
inline_codes.append(m.group(1)) inline_codes.append(m.group(1))
return f"\x00IC{len(inline_codes) - 1}\x00" return f"\x00IC{len(inline_codes) - 1}\x00"
text = re.sub(r'`([^`]+)`', save_inline_code, text) text = re.sub(r'`([^`]+)`', save_inline_code, text)
# 3. Headers # Title -> just the title text # 3. Headers # Title -> just the title text
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
# 4. Blockquotes > text -> just the text (before HTML escaping) # 4. Blockquotes > text -> just the text (before HTML escaping)
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
# 5. Escape HTML special characters # 5. Escape HTML special characters
text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
# 6. Links [text](url) - must be before bold/italic to handle nested cases # 6. Links [text](url) - must be before bold/italic to handle nested cases
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text) text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
# 7. Bold **text** or __text__ # 7. Bold **text** or __text__
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text) text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text) text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
# 8. Italic _text_ (avoid matching inside words like some_var_name) # 8. Italic _text_ (avoid matching inside words like some_var_name)
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text) text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
# 9. Strikethrough ~~text~~ # 9. Strikethrough ~~text~~
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text) text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
# 10. Bullet lists - item -> • item # 10. Bullet lists - item -> • item
text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE) text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE)
# 11. Restore inline code with HTML tags # 11. Restore inline code with HTML tags
for i, code in enumerate(inline_codes): for i, code in enumerate(inline_codes):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>") text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
# 12. Restore code blocks with HTML tags # 12. Restore code blocks with HTML tags
for i, code in enumerate(code_blocks): for i, code in enumerate(code_blocks):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>") text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
return text return text
@@ -101,12 +102,12 @@ def _split_message(content: str, max_len: int = 4000) -> list[str]:
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
""" """
Telegram channel using long polling. Telegram channel using long polling.
Simple and reliable - no webhook/public IP needed. Simple and reliable - no webhook/public IP needed.
""" """
name = "telegram" name = "telegram"
# Commands registered with Telegram's command menu # Commands registered with Telegram's command menu
BOT_COMMANDS = [ BOT_COMMANDS = [
BotCommand("start", "Start the bot"), BotCommand("start", "Start the bot"),
@@ -114,7 +115,7 @@ class TelegramChannel(BaseChannel):
BotCommand("stop", "Stop the current task"), BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
def __init__( def __init__(
self, self,
config: TelegramConfig, config: TelegramConfig,
@@ -129,15 +130,15 @@ class TelegramChannel(BaseChannel):
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {} self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {} self._media_group_tasks: dict[str, asyncio.Task] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
if not self.config.token: if not self.config.token:
logger.error("Telegram bot token not configured") logger.error("Telegram bot token not configured")
return return
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs # Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
@@ -145,51 +146,51 @@ class TelegramChannel(BaseChannel):
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy) builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
# Add command handlers # Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help)) self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
self._app.add_handler( self._app.add_handler(
MessageHandler( MessageHandler(
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL) (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
& ~filters.COMMAND, & ~filters.COMMAND,
self._on_message self._on_message
) )
) )
logger.info("Starting Telegram bot (polling mode)...") logger.info("Starting Telegram bot (polling mode)...")
# Initialize and start polling # Initialize and start polling
await self._app.initialize() await self._app.initialize()
await self._app.start() await self._app.start()
# Get bot info and register command menu # Get bot info and register command menu
bot_info = await self._app.bot.get_me() bot_info = await self._app.bot.get_me()
logger.info("Telegram bot @{} connected", bot_info.username) logger.info("Telegram bot @{} connected", bot_info.username)
try: try:
await self._app.bot.set_my_commands(self.BOT_COMMANDS) await self._app.bot.set_my_commands(self.BOT_COMMANDS)
logger.debug("Telegram bot commands registered") logger.debug("Telegram bot commands registered")
except Exception as e: except Exception as e:
logger.warning("Failed to register bot commands: {}", e) logger.warning("Failed to register bot commands: {}", e)
# Start polling (this runs until stopped) # Start polling (this runs until stopped)
await self._app.updater.start_polling( await self._app.updater.start_polling(
allowed_updates=["message"], allowed_updates=["message"],
drop_pending_updates=True # Ignore old messages on startup drop_pending_updates=True # Ignore old messages on startup
) )
# Keep running until stopped # Keep running until stopped
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the Telegram bot.""" """Stop the Telegram bot."""
self._running = False self._running = False
# Cancel all typing indicators # Cancel all typing indicators
for chat_id in list(self._typing_tasks): for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id) self._stop_typing(chat_id)
@@ -198,14 +199,14 @@ class TelegramChannel(BaseChannel):
task.cancel() task.cancel()
self._media_group_tasks.clear() self._media_group_tasks.clear()
self._media_group_buffers.clear() self._media_group_buffers.clear()
if self._app: if self._app:
logger.info("Stopping Telegram bot...") logger.info("Stopping Telegram bot...")
await self._app.updater.stop() await self._app.updater.stop()
await self._app.stop() await self._app.stop()
await self._app.shutdown() await self._app.shutdown()
self._app = None self._app = None
@staticmethod @staticmethod
def _get_media_type(path: str) -> str: def _get_media_type(path: str) -> str:
"""Guess media type from file extension.""" """Guess media type from file extension."""
@@ -253,7 +254,7 @@ class TelegramChannel(BaseChannel):
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document" param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
with open(media_path, 'rb') as f: with open(media_path, 'rb') as f:
await sender( await sender(
chat_id=chat_id, chat_id=chat_id,
**{param: f}, **{param: f},
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -272,8 +273,8 @@ class TelegramChannel(BaseChannel):
try: try:
html = _markdown_to_telegram_html(chunk) html = _markdown_to_telegram_html(chunk)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=html, text=html,
parse_mode="HTML", parse_mode="HTML",
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -281,13 +282,13 @@ class TelegramChannel(BaseChannel):
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=chunk, text=chunk,
reply_parameters=reply_params reply_parameters=reply_params
) )
except Exception as e2: except Exception as e2:
logger.error("Error sending Telegram message: {}", e2) logger.error("Error sending Telegram message: {}", e2)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
@@ -326,34 +327,34 @@ class TelegramChannel(BaseChannel):
chat_id=str(update.message.chat_id), chat_id=str(update.message.chat_id),
content=update.message.text, content=update.message.text,
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming messages (text, photos, voice, documents).""" """Handle incoming messages (text, photos, voice, documents)."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message message = update.message
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user) sender_id = self._sender_id(user)
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
# Build content from text and/or media # Build content from text and/or media
content_parts = [] content_parts = []
media_paths = [] media_paths = []
# Text content # Text content
if message.text: if message.text:
content_parts.append(message.text) content_parts.append(message.text)
if message.caption: if message.caption:
content_parts.append(message.caption) content_parts.append(message.caption)
# Handle media files # Handle media files
media_file = None media_file = None
media_type = None media_type = None
if message.photo: if message.photo:
media_file = message.photo[-1] # Largest photo media_file = message.photo[-1] # Largest photo
media_type = "image" media_type = "image"
@@ -366,23 +367,23 @@ class TelegramChannel(BaseChannel):
elif message.document: elif message.document:
media_file = message.document media_file = message.document
media_type = "file" media_type = "file"
# Download media if present # Download media if present
if media_file and self._app: if media_file and self._app:
try: try:
file = await self._app.bot.get_file(media_file.file_id) file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None)) ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
# Save to workspace/media/ # Save to workspace/media/
from pathlib import Path from pathlib import Path
media_dir = Path.home() / ".nanobot" / "media" media_dir = Path.home() / ".nanobot" / "media"
media_dir.mkdir(parents=True, exist_ok=True) media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{media_file.file_id[:16]}{ext}" file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path)) await file.download_to_drive(str(file_path))
media_paths.append(str(file_path)) media_paths.append(str(file_path))
# Handle voice transcription # Handle voice transcription
if media_type == "voice" or media_type == "audio": if media_type == "voice" or media_type == "audio":
from nanobot.providers.transcription import GroqTranscriptionProvider from nanobot.providers.transcription import GroqTranscriptionProvider
@@ -395,16 +396,16 @@ class TelegramChannel(BaseChannel):
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
else: else:
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
logger.debug("Downloaded {} to {}", media_type, file_path) logger.debug("Downloaded {} to {}", media_type, file_path)
except Exception as e: except Exception as e:
logger.error("Failed to download media: {}", e) logger.error("Failed to download media: {}", e)
content_parts.append(f"[{media_type}: download failed]") content_parts.append(f"[{media_type}: download failed]")
content = "\n".join(content_parts) if content_parts else "[empty message]" content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
# Telegram media groups: buffer briefly, forward as one aggregated turn. # Telegram media groups: buffer briefly, forward as one aggregated turn.
@@ -428,10 +429,10 @@ class TelegramChannel(BaseChannel):
if key not in self._media_group_tasks: if key not in self._media_group_tasks:
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key)) self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
return return
# Start typing indicator before processing # Start typing indicator before processing
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
# Forward to the message bus # Forward to the message bus
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
@@ -446,7 +447,7 @@ class TelegramChannel(BaseChannel):
"is_group": message.chat.type != "private" "is_group": message.chat.type != "private"
} }
) )
async def _flush_media_group(self, key: str) -> None: async def _flush_media_group(self, key: str) -> None:
"""Wait briefly, then forward buffered media-group as one turn.""" """Wait briefly, then forward buffered media-group as one turn."""
try: try:
@@ -467,13 +468,13 @@ class TelegramChannel(BaseChannel):
# Cancel any existing typing task for this chat # Cancel any existing typing task for this chat
self._stop_typing(chat_id) self._stop_typing(chat_id)
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
def _stop_typing(self, chat_id: str) -> None: def _stop_typing(self, chat_id: str) -> None:
"""Stop the typing indicator for a chat.""" """Stop the typing indicator for a chat."""
task = self._typing_tasks.pop(chat_id, None) task = self._typing_tasks.pop(chat_id, None)
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
async def _typing_loop(self, chat_id: str) -> None: async def _typing_loop(self, chat_id: str) -> None:
"""Repeatedly send 'typing' action until cancelled.""" """Repeatedly send 'typing' action until cancelled."""
try: try:
@@ -484,7 +485,7 @@ class TelegramChannel(BaseChannel):
pass pass
except Exception as e: except Exception as e:
logger.debug("Typing indicator stopped for {}: {}", chat_id, e) logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Log polling / handler errors instead of silently swallowing them.""" """Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error) logger.error("Telegram error: {}", context.error)
@@ -498,6 +499,6 @@ class TelegramChannel(BaseChannel):
} }
if mime_type in ext_map: if mime_type in ext_map:
return ext_map[mime_type] return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""} type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "") return type_map.get(media_type, "")

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
import json import json
from collections import OrderedDict from collections import OrderedDict
from typing import Any
from loguru import logger from loguru import logger
@@ -29,17 +28,17 @@ class WhatsAppChannel(BaseChannel):
self._ws = None self._ws = None
self._connected = False self._connected = False
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
async def start(self) -> None: async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge.""" """Start the WhatsApp channel by connecting to the bridge."""
import websockets import websockets
bridge_url = self.config.bridge_url bridge_url = self.config.bridge_url
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url) logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
self._running = True self._running = True
while self._running: while self._running:
try: try:
async with websockets.connect(bridge_url) as ws: async with websockets.connect(bridge_url) as ws:
@@ -49,40 +48,40 @@ class WhatsAppChannel(BaseChannel):
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token})) await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
self._connected = True self._connected = True
logger.info("Connected to WhatsApp bridge") logger.info("Connected to WhatsApp bridge")
# Listen for messages # Listen for messages
async for message in ws: async for message in ws:
try: try:
await self._handle_bridge_message(message) await self._handle_bridge_message(message)
except Exception as e: except Exception as e:
logger.error("Error handling bridge message: {}", e) logger.error("Error handling bridge message: {}", e)
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
self._connected = False self._connected = False
self._ws = None self._ws = None
logger.warning("WhatsApp bridge connection error: {}", e) logger.warning("WhatsApp bridge connection error: {}", e)
if self._running: if self._running:
logger.info("Reconnecting in 5 seconds...") logger.info("Reconnecting in 5 seconds...")
await asyncio.sleep(5) await asyncio.sleep(5)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the WhatsApp channel.""" """Stop the WhatsApp channel."""
self._running = False self._running = False
self._connected = False self._connected = False
if self._ws: if self._ws:
await self._ws.close() await self._ws.close()
self._ws = None self._ws = None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WhatsApp.""" """Send a message through WhatsApp."""
if not self._ws or not self._connected: if not self._ws or not self._connected:
logger.warning("WhatsApp bridge not connected") logger.warning("WhatsApp bridge not connected")
return return
try: try:
payload = { payload = {
"type": "send", "type": "send",
@@ -92,7 +91,7 @@ class WhatsAppChannel(BaseChannel):
await self._ws.send(json.dumps(payload, ensure_ascii=False)) await self._ws.send(json.dumps(payload, ensure_ascii=False))
except Exception as e: except Exception as e:
logger.error("Error sending WhatsApp message: {}", e) logger.error("Error sending WhatsApp message: {}", e)
async def _handle_bridge_message(self, raw: str) -> None: async def _handle_bridge_message(self, raw: str) -> None:
"""Handle a message from the bridge.""" """Handle a message from the bridge."""
try: try:
@@ -100,9 +99,9 @@ class WhatsAppChannel(BaseChannel):
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning("Invalid JSON from bridge: {}", raw[:100]) logger.warning("Invalid JSON from bridge: {}", raw[:100])
return return
msg_type = data.get("type") msg_type = data.get("type")
if msg_type == "message": if msg_type == "message":
# Incoming message from WhatsApp # Incoming message from WhatsApp
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net # Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
@@ -139,20 +138,20 @@ class WhatsAppChannel(BaseChannel):
"is_group": data.get("isGroup", False) "is_group": data.get("isGroup", False)
} }
) )
elif msg_type == "status": elif msg_type == "status":
# Connection status update # Connection status update
status = data.get("status") status = data.get("status")
logger.info("WhatsApp status: {}", status) logger.info("WhatsApp status: {}", status)
if status == "connected": if status == "connected":
self._connected = True self._connected = True
elif status == "disconnected": elif status == "disconnected":
self._connected = False self._connected = False
elif msg_type == "qr": elif msg_type == "qr":
# QR code for authentication # QR code for authentication
logger.info("Scan QR code in the bridge terminal to connect WhatsApp") logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
elif msg_type == "error": elif msg_type == "error":
logger.error("WhatsApp bridge error: {}", data.get('error')) logger.error("WhatsApp bridge error: {}", data.get('error'))

View File

@@ -2,23 +2,22 @@
import asyncio import asyncio
import os import os
import signal
from pathlib import Path
import select import select
import signal
import sys import sys
from pathlib import Path
import typer import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
from rich.text import Text from rich.text import Text
from prompt_toolkit import PromptSession from nanobot import __logo__, __version__
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from nanobot import __version__, __logo__
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates from nanobot.utils.helpers import sync_workspace_templates
@@ -160,9 +159,9 @@ def onboard():
from nanobot.config.loader import get_config_path, load_config, save_config from nanobot.config.loader import get_config_path, load_config, save_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import get_workspace_path from nanobot.utils.helpers import get_workspace_path
config_path = get_config_path() config_path = get_config_path()
if config_path.exists(): if config_path.exists():
console.print(f"[yellow]Config already exists at {config_path}[/yellow]") console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
@@ -178,16 +177,16 @@ def onboard():
else: else:
save_config(Config()) save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}") console.print(f"[green]✓[/green] Created config at {config_path}")
# Create workspace # Create workspace
workspace = get_workspace_path() workspace = get_workspace_path()
if not workspace.exists(): if not workspace.exists():
workspace.mkdir(parents=True, exist_ok=True) workspace.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}") console.print(f"[green]✓[/green] Created workspace at {workspace}")
sync_workspace_templates(workspace) sync_workspace_templates(workspace)
console.print(f"\n{__logo__} nanobot is ready!") console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:") console.print("\nNext steps:")
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]") console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
@@ -201,9 +200,9 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.custom_provider import CustomProvider
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.custom_provider import CustomProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)
@@ -248,31 +247,31 @@ def gateway(
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
): ):
"""Start the nanobot gateway.""" """Start the nanobot gateway."""
from nanobot.config.loader import load_config, get_data_dir
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.session.manager import SessionManager from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.session.manager import SessionManager
if verbose: if verbose:
import logging import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
console.print(f"{__logo__} Starting nanobot gateway on port {port}...") console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
config = load_config() config = load_config()
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path) session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation) # Create cron service first (callback set after agent creation)
cron_store_path = get_data_dir() / "cron" / "jobs.json" cron_store_path = get_data_dir() / "cron" / "jobs.json"
cron = CronService(cron_store_path) cron = CronService(cron_store_path)
# Create agent with cron service # Create agent with cron service
agent = AgentLoop( agent = AgentLoop(
bus=bus, bus=bus,
@@ -291,7 +290,7 @@ def gateway(
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels, channels_config=config.channels,
) )
# Set cron callback (needs agent) # Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None: async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent.""" """Execute a cron job through the agent."""
@@ -310,7 +309,7 @@ def gateway(
)) ))
return response return response
cron.on_job = on_cron_job cron.on_job = on_cron_job
# Create channel manager # Create channel manager
channels = ChannelManager(config, bus) channels = ChannelManager(config, bus)
@@ -364,18 +363,18 @@ def gateway(
interval_s=hb_cfg.interval_s, interval_s=hb_cfg.interval_s,
enabled=hb_cfg.enabled, enabled=hb_cfg.enabled,
) )
if channels.enabled_channels: if channels.enabled_channels:
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}") console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
else: else:
console.print("[yellow]Warning: No channels enabled[/yellow]") console.print("[yellow]Warning: No channels enabled[/yellow]")
cron_status = cron.status() cron_status = cron.status()
if cron_status["jobs"] > 0: if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs") console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s") console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
async def run(): async def run():
try: try:
await cron.start() await cron.start()
@@ -392,7 +391,7 @@ def gateway(
cron.stop() cron.stop()
agent.stop() agent.stop()
await channels.stop_all() await channels.stop_all()
asyncio.run(run()) asyncio.run(run())
@@ -411,15 +410,16 @@ def agent(
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
): ):
"""Interact with the agent directly.""" """Interact with the agent directly."""
from nanobot.config.loader import load_config, get_data_dir
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
from nanobot.cron.service import CronService
from loguru import logger from loguru import logger
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService
config = load_config() config = load_config()
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
@@ -431,7 +431,7 @@ def agent(
logger.enable("nanobot") logger.enable("nanobot")
else: else:
logger.disable("nanobot") logger.disable("nanobot")
agent_loop = AgentLoop( agent_loop = AgentLoop(
bus=bus, bus=bus,
provider=provider, provider=provider,
@@ -448,7 +448,7 @@ def agent(
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels, channels_config=config.channels,
) )
# Show spinner when logs are off (no output to miss); skip when logs are on # Show spinner when logs are off (no output to miss); skip when logs are on
def _thinking_ctx(): def _thinking_ctx():
if logs: if logs:
@@ -624,7 +624,7 @@ def channels_status():
"" if mc.enabled else "", "" if mc.enabled else "",
mc_base mc_base
) )
# Telegram # Telegram
tg = config.channels.telegram tg = config.channels.telegram
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]" tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
@@ -677,57 +677,57 @@ def _get_bridge_dir() -> Path:
"""Get the bridge directory, setting it up if needed.""" """Get the bridge directory, setting it up if needed."""
import shutil import shutil
import subprocess import subprocess
# User's bridge location # User's bridge location
user_bridge = Path.home() / ".nanobot" / "bridge" user_bridge = Path.home() / ".nanobot" / "bridge"
# Check if already built # Check if already built
if (user_bridge / "dist" / "index.js").exists(): if (user_bridge / "dist" / "index.js").exists():
return user_bridge return user_bridge
# Check for npm # Check for npm
if not shutil.which("npm"): if not shutil.which("npm"):
console.print("[red]npm not found. Please install Node.js >= 18.[/red]") console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
raise typer.Exit(1) raise typer.Exit(1)
# Find source bridge: first check package data, then source dir # Find source bridge: first check package data, then source dir
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed) pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev) src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
source = None source = None
if (pkg_bridge / "package.json").exists(): if (pkg_bridge / "package.json").exists():
source = pkg_bridge source = pkg_bridge
elif (src_bridge / "package.json").exists(): elif (src_bridge / "package.json").exists():
source = src_bridge source = src_bridge
if not source: if not source:
console.print("[red]Bridge source not found.[/red]") console.print("[red]Bridge source not found.[/red]")
console.print("Try reinstalling: pip install --force-reinstall nanobot") console.print("Try reinstalling: pip install --force-reinstall nanobot")
raise typer.Exit(1) raise typer.Exit(1)
console.print(f"{__logo__} Setting up bridge...") console.print(f"{__logo__} Setting up bridge...")
# Copy to user directory # Copy to user directory
user_bridge.parent.mkdir(parents=True, exist_ok=True) user_bridge.parent.mkdir(parents=True, exist_ok=True)
if user_bridge.exists(): if user_bridge.exists():
shutil.rmtree(user_bridge) shutil.rmtree(user_bridge)
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist")) shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
# Install and build # Install and build
try: try:
console.print(" Installing dependencies...") console.print(" Installing dependencies...")
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True) subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
console.print(" Building...") console.print(" Building...")
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True) subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
console.print("[green]✓[/green] Bridge ready\n") console.print("[green]✓[/green] Bridge ready\n")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
console.print(f"[red]Build failed: {e}[/red]") console.print(f"[red]Build failed: {e}[/red]")
if e.stderr: if e.stderr:
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]") console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
raise typer.Exit(1) raise typer.Exit(1)
return user_bridge return user_bridge
@@ -735,18 +735,19 @@ def _get_bridge_dir() -> Path:
def channels_login(): def channels_login():
"""Link device via QR code.""" """Link device via QR code."""
import subprocess import subprocess
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
config = load_config() config = load_config()
bridge_dir = _get_bridge_dir() bridge_dir = _get_bridge_dir()
console.print(f"{__logo__} Starting bridge...") console.print(f"{__logo__} Starting bridge...")
console.print("Scan the QR code to connect.\n") console.print("Scan the QR code to connect.\n")
env = {**os.environ} env = {**os.environ}
if config.channels.whatsapp.bridge_token: if config.channels.whatsapp.bridge_token:
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
try: try:
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env) subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -770,23 +771,23 @@ def cron_list(
"""List scheduled jobs.""" """List scheduled jobs."""
from nanobot.config.loader import get_data_dir from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json" store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path) service = CronService(store_path)
jobs = service.list_jobs(include_disabled=all) jobs = service.list_jobs(include_disabled=all)
if not jobs: if not jobs:
console.print("No scheduled jobs.") console.print("No scheduled jobs.")
return return
table = Table(title="Scheduled Jobs") table = Table(title="Scheduled Jobs")
table.add_column("ID", style="cyan") table.add_column("ID", style="cyan")
table.add_column("Name") table.add_column("Name")
table.add_column("Schedule") table.add_column("Schedule")
table.add_column("Status") table.add_column("Status")
table.add_column("Next Run") table.add_column("Next Run")
import time import time
from datetime import datetime as _dt from datetime import datetime as _dt
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@@ -798,7 +799,7 @@ def cron_list(
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "") sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
else: else:
sched = "one-time" sched = "one-time"
# Format next run # Format next run
next_run = "" next_run = ""
if job.state.next_run_at_ms: if job.state.next_run_at_ms:
@@ -808,11 +809,11 @@ def cron_list(
next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M") next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
except Exception: except Exception:
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts)) next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]" status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
table.add_row(job.id, job.name, sched, status, next_run) table.add_row(job.id, job.name, sched, status, next_run)
console.print(table) console.print(table)
@@ -832,7 +833,7 @@ def cron_add(
from nanobot.config.loader import get_data_dir from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule from nanobot.cron.types import CronSchedule
if tz and not cron_expr: if tz and not cron_expr:
console.print("[red]Error: --tz can only be used with --cron[/red]") console.print("[red]Error: --tz can only be used with --cron[/red]")
raise typer.Exit(1) raise typer.Exit(1)
@@ -849,10 +850,10 @@ def cron_add(
else: else:
console.print("[red]Error: Must specify --every, --cron, or --at[/red]") console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
raise typer.Exit(1) raise typer.Exit(1)
store_path = get_data_dir() / "cron" / "jobs.json" store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path) service = CronService(store_path)
try: try:
job = service.add_job( job = service.add_job(
name=name, name=name,
@@ -876,10 +877,10 @@ def cron_remove(
"""Remove a scheduled job.""" """Remove a scheduled job."""
from nanobot.config.loader import get_data_dir from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json" store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path) service = CronService(store_path)
if service.remove_job(job_id): if service.remove_job(job_id):
console.print(f"[green]✓[/green] Removed job {job_id}") console.print(f"[green]✓[/green] Removed job {job_id}")
else: else:
@@ -894,10 +895,10 @@ def cron_enable(
"""Enable or disable a job.""" """Enable or disable a job."""
from nanobot.config.loader import get_data_dir from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json" store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path) service = CronService(store_path)
job = service.enable_job(job_id, enabled=not disable) job = service.enable_job(job_id, enabled=not disable)
if job: if job:
status = "disabled" if disable else "enabled" status = "disabled" if disable else "enabled"
@@ -913,11 +914,12 @@ def cron_run(
): ):
"""Manually run a job.""" """Manually run a job."""
from loguru import logger from loguru import logger
from nanobot.config.loader import load_config, get_data_dir
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
logger.disable("nanobot") logger.disable("nanobot")
config = load_config() config = load_config()
@@ -975,7 +977,7 @@ def cron_run(
@app.command() @app.command()
def status(): def status():
"""Show nanobot status.""" """Show nanobot status."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
config_path = get_config_path() config_path = get_config_path()
config = load_config() config = load_config()
@@ -990,7 +992,7 @@ def status():
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS
console.print(f"Model: {config.agents.defaults.model}") console.print(f"Model: {config.agents.defaults.model}")
# Check API keys from registry # Check API keys from registry
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(config.providers, spec.name, None) p = getattr(config.providers, spec.name, None)

View File

@@ -1,6 +1,6 @@
"""Configuration module for nanobot.""" """Configuration module for nanobot."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
__all__ = ["Config", "load_config", "get_config_path"] __all__ = ["Config", "load_config", "get_config_path"]

View File

@@ -3,7 +3,7 @@
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings

View File

@@ -21,17 +21,18 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
"""Compute next run time in ms.""" """Compute next run time in ms."""
if schedule.kind == "at": if schedule.kind == "at":
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
if schedule.kind == "every": if schedule.kind == "every":
if not schedule.every_ms or schedule.every_ms <= 0: if not schedule.every_ms or schedule.every_ms <= 0:
return None return None
# Next interval from now # Next interval from now
return now_ms + schedule.every_ms return now_ms + schedule.every_ms
if schedule.kind == "cron" and schedule.expr: if schedule.kind == "cron" and schedule.expr:
try: try:
from croniter import croniter
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from croniter import croniter
# Use caller-provided reference time for deterministic scheduling # Use caller-provided reference time for deterministic scheduling
base_time = now_ms / 1000 base_time = now_ms / 1000
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
@@ -41,7 +42,7 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
return int(next_dt.timestamp() * 1000) return int(next_dt.timestamp() * 1000)
except Exception: except Exception:
return None return None
return None return None
@@ -61,7 +62,7 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
class CronService: class CronService:
"""Service for managing and executing scheduled jobs.""" """Service for managing and executing scheduled jobs."""
def __init__( def __init__(
self, self,
store_path: Path, store_path: Path,
@@ -72,12 +73,12 @@ class CronService:
self._store: CronStore | None = None self._store: CronStore | None = None
self._timer_task: asyncio.Task | None = None self._timer_task: asyncio.Task | None = None
self._running = False self._running = False
def _load_store(self) -> CronStore: def _load_store(self) -> CronStore:
"""Load jobs from disk.""" """Load jobs from disk."""
if self._store: if self._store:
return self._store return self._store
if self.store_path.exists(): if self.store_path.exists():
try: try:
data = json.loads(self.store_path.read_text(encoding="utf-8")) data = json.loads(self.store_path.read_text(encoding="utf-8"))
@@ -117,16 +118,16 @@ class CronService:
self._store = CronStore() self._store = CronStore()
else: else:
self._store = CronStore() self._store = CronStore()
return self._store return self._store
def _save_store(self) -> None: def _save_store(self) -> None:
"""Save jobs to disk.""" """Save jobs to disk."""
if not self._store: if not self._store:
return return
self.store_path.parent.mkdir(parents=True, exist_ok=True) self.store_path.parent.mkdir(parents=True, exist_ok=True)
data = { data = {
"version": self._store.version, "version": self._store.version,
"jobs": [ "jobs": [
@@ -161,9 +162,9 @@ class CronService:
for j in self._store.jobs for j in self._store.jobs
] ]
} }
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8") self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
async def start(self) -> None: async def start(self) -> None:
"""Start the cron service.""" """Start the cron service."""
self._running = True self._running = True
@@ -172,14 +173,14 @@ class CronService:
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else [])) logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
def stop(self) -> None: def stop(self) -> None:
"""Stop the cron service.""" """Stop the cron service."""
self._running = False self._running = False
if self._timer_task: if self._timer_task:
self._timer_task.cancel() self._timer_task.cancel()
self._timer_task = None self._timer_task = None
def _recompute_next_runs(self) -> None: def _recompute_next_runs(self) -> None:
"""Recompute next run times for all enabled jobs.""" """Recompute next run times for all enabled jobs."""
if not self._store: if not self._store:
@@ -188,73 +189,73 @@ class CronService:
for job in self._store.jobs: for job in self._store.jobs:
if job.enabled: if job.enabled:
job.state.next_run_at_ms = _compute_next_run(job.schedule, now) job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
def _get_next_wake_ms(self) -> int | None: def _get_next_wake_ms(self) -> int | None:
"""Get the earliest next run time across all jobs.""" """Get the earliest next run time across all jobs."""
if not self._store: if not self._store:
return None return None
times = [j.state.next_run_at_ms for j in self._store.jobs times = [j.state.next_run_at_ms for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms] if j.enabled and j.state.next_run_at_ms]
return min(times) if times else None return min(times) if times else None
def _arm_timer(self) -> None: def _arm_timer(self) -> None:
"""Schedule the next timer tick.""" """Schedule the next timer tick."""
if self._timer_task: if self._timer_task:
self._timer_task.cancel() self._timer_task.cancel()
next_wake = self._get_next_wake_ms() next_wake = self._get_next_wake_ms()
if not next_wake or not self._running: if not next_wake or not self._running:
return return
delay_ms = max(0, next_wake - _now_ms()) delay_ms = max(0, next_wake - _now_ms())
delay_s = delay_ms / 1000 delay_s = delay_ms / 1000
async def tick(): async def tick():
await asyncio.sleep(delay_s) await asyncio.sleep(delay_s)
if self._running: if self._running:
await self._on_timer() await self._on_timer()
self._timer_task = asyncio.create_task(tick()) self._timer_task = asyncio.create_task(tick())
async def _on_timer(self) -> None: async def _on_timer(self) -> None:
"""Handle timer tick - run due jobs.""" """Handle timer tick - run due jobs."""
if not self._store: if not self._store:
return return
now = _now_ms() now = _now_ms()
due_jobs = [ due_jobs = [
j for j in self._store.jobs j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
] ]
for job in due_jobs: for job in due_jobs:
await self._execute_job(job) await self._execute_job(job)
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
async def _execute_job(self, job: CronJob) -> None: async def _execute_job(self, job: CronJob) -> None:
"""Execute a single job.""" """Execute a single job."""
start_ms = _now_ms() start_ms = _now_ms()
logger.info("Cron: executing job '{}' ({})", job.name, job.id) logger.info("Cron: executing job '{}' ({})", job.name, job.id)
try: try:
response = None response = None
if self.on_job: if self.on_job:
response = await self.on_job(job) response = await self.on_job(job)
job.state.last_status = "ok" job.state.last_status = "ok"
job.state.last_error = None job.state.last_error = None
logger.info("Cron: job '{}' completed", job.name) logger.info("Cron: job '{}' completed", job.name)
except Exception as e: except Exception as e:
job.state.last_status = "error" job.state.last_status = "error"
job.state.last_error = str(e) job.state.last_error = str(e)
logger.error("Cron: job '{}' failed: {}", job.name, e) logger.error("Cron: job '{}' failed: {}", job.name, e)
job.state.last_run_at_ms = start_ms job.state.last_run_at_ms = start_ms
job.updated_at_ms = _now_ms() job.updated_at_ms = _now_ms()
# Handle one-shot jobs # Handle one-shot jobs
if job.schedule.kind == "at": if job.schedule.kind == "at":
if job.delete_after_run: if job.delete_after_run:
@@ -265,15 +266,15 @@ class CronService:
else: else:
# Compute next run # Compute next run
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms()) job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
# ========== Public API ========== # ========== Public API ==========
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]: def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
"""List all jobs.""" """List all jobs."""
store = self._load_store() store = self._load_store()
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled] jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf')) return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
def add_job( def add_job(
self, self,
name: str, name: str,
@@ -288,7 +289,7 @@ class CronService:
store = self._load_store() store = self._load_store()
_validate_schedule_for_add(schedule) _validate_schedule_for_add(schedule)
now = _now_ms() now = _now_ms()
job = CronJob( job = CronJob(
id=str(uuid.uuid4())[:8], id=str(uuid.uuid4())[:8],
name=name, name=name,
@@ -306,28 +307,28 @@ class CronService:
updated_at_ms=now, updated_at_ms=now,
delete_after_run=delete_after_run, delete_after_run=delete_after_run,
) )
store.jobs.append(job) store.jobs.append(job)
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron: added job '{}' ({})", name, job.id) logger.info("Cron: added job '{}' ({})", name, job.id)
return job return job
def remove_job(self, job_id: str) -> bool: def remove_job(self, job_id: str) -> bool:
"""Remove a job by ID.""" """Remove a job by ID."""
store = self._load_store() store = self._load_store()
before = len(store.jobs) before = len(store.jobs)
store.jobs = [j for j in store.jobs if j.id != job_id] store.jobs = [j for j in store.jobs if j.id != job_id]
removed = len(store.jobs) < before removed = len(store.jobs) < before
if removed: if removed:
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron: removed job {}", job_id) logger.info("Cron: removed job {}", job_id)
return removed return removed
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None: def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
"""Enable or disable a job.""" """Enable or disable a job."""
store = self._load_store() store = self._load_store()
@@ -343,7 +344,7 @@ class CronService:
self._arm_timer() self._arm_timer()
return job return job
return None return None
async def run_job(self, job_id: str, force: bool = False) -> bool: async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job.""" """Manually run a job."""
store = self._load_store() store = self._load_store()
@@ -356,7 +357,7 @@ class CronService:
self._arm_timer() self._arm_timer()
return True return True
return False return False
def status(self) -> dict: def status(self) -> dict:
"""Get service status.""" """Get service status."""
store = self._load_store() store = self._load_store()

View File

@@ -21,7 +21,7 @@ class LLMResponse:
finish_reason: str = "stop" finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict) usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc. reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
@property @property
def has_tool_calls(self) -> bool: def has_tool_calls(self) -> bool:
"""Check if response contains tool calls.""" """Check if response contains tool calls."""
@@ -35,7 +35,7 @@ class LLMProvider(ABC):
Implementations should handle the specifics of each provider's API Implementations should handle the specifics of each provider's API
while maintaining a consistent interface. while maintaining a consistent interface.
""" """
def __init__(self, api_key: str | None = None, api_base: str | None = None): def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
@@ -79,7 +79,7 @@ class LLMProvider(ABC):
result.append(msg) result.append(msg)
return result return result
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,
@@ -103,7 +103,7 @@ class LLMProvider(ABC):
LLMResponse with content and/or tool calls. LLMResponse with content and/or tool calls.
""" """
pass pass
@abstractmethod @abstractmethod
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model for this provider.""" """Get the default model for this provider."""

View File

@@ -1,19 +1,17 @@
"""LiteLLM provider implementation for multi-provider support.""" """LiteLLM provider implementation for multi-provider support."""
import json
import json_repair
import os import os
import secrets import secrets
import string import string
from typing import Any from typing import Any
import json_repair
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway from nanobot.providers.registry import find_by_model, find_gateway
# Standard OpenAI chat-completion message keys plus reasoning_content for # Standard OpenAI chat-completion message keys plus reasoning_content for
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.). # thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
@@ -32,10 +30,10 @@ class LiteLLMProvider(LLMProvider):
a unified interface. Provider-specific logic is driven by the registry a unified interface. Provider-specific logic is driven by the registry
(see providers/registry.py) — no if-elif chains needed here. (see providers/registry.py) — no if-elif chains needed here.
""" """
def __init__( def __init__(
self, self,
api_key: str | None = None, api_key: str | None = None,
api_base: str | None = None, api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5", default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None,
@@ -44,24 +42,24 @@ class LiteLLMProvider(LLMProvider):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
self.extra_headers = extra_headers or {} self.extra_headers = extra_headers or {}
# Detect gateway / local deployment. # Detect gateway / local deployment.
# provider_name (from config key) is the primary signal; # provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection. # api_key / api_base are fallback for auto-detection.
self._gateway = find_gateway(provider_name, api_key, api_base) self._gateway = find_gateway(provider_name, api_key, api_base)
# Configure environment variables # Configure environment variables
if api_key: if api_key:
self._setup_env(api_key, api_base, default_model) self._setup_env(api_key, api_base, default_model)
if api_base: if api_base:
litellm.api_base = api_base litellm.api_base = api_base
# Disable LiteLLM logging noise # Disable LiteLLM logging noise
litellm.suppress_debug_info = True litellm.suppress_debug_info = True
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params) # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
litellm.drop_params = True litellm.drop_params = True
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider.""" """Set environment variables based on detected provider."""
spec = self._gateway or find_by_model(model) spec = self._gateway or find_by_model(model)
@@ -85,7 +83,7 @@ class LiteLLMProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key) resolved = env_val.replace("{api_key}", api_key)
resolved = resolved.replace("{api_base}", effective_base) resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved) os.environ.setdefault(env_name, resolved)
def _resolve_model(self, model: str) -> str: def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes.""" """Resolve model name by applying provider/gateway prefixes."""
if self._gateway: if self._gateway:
@@ -96,7 +94,7 @@ class LiteLLMProvider(LLMProvider):
if prefix and not model.startswith(f"{prefix}/"): if prefix and not model.startswith(f"{prefix}/"):
model = f"{prefix}/{model}" model = f"{prefix}/{model}"
return model return model
# Standard mode: auto-prefix for known providers # Standard mode: auto-prefix for known providers
spec = find_by_model(model) spec = find_by_model(model)
if spec and spec.litellm_prefix: if spec and spec.litellm_prefix:
@@ -115,7 +113,7 @@ class LiteLLMProvider(LLMProvider):
if prefix.lower().replace("-", "_") != spec_name: if prefix.lower().replace("-", "_") != spec_name:
return model return model
return f"{canonical_prefix}/{remainder}" return f"{canonical_prefix}/{remainder}"
def _supports_cache_control(self, model: str) -> bool: def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks.""" """Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None: if self._gateway is not None:
@@ -158,7 +156,7 @@ class LiteLLMProvider(LLMProvider):
if pattern in model_lower: if pattern in model_lower:
kwargs.update(overrides) kwargs.update(overrides)
return return
@staticmethod @staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key.""" """Strip non-standard keys and ensure assistant messages have a content key."""
@@ -181,14 +179,14 @@ class LiteLLMProvider(LLMProvider):
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request via LiteLLM. Send a chat completion request via LiteLLM.
Args: Args:
messages: List of message dicts with 'role' and 'content'. messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format. tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5'). model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response. max_tokens: Maximum tokens in response.
temperature: Sampling temperature. temperature: Sampling temperature.
Returns: Returns:
LLMResponse with content and/or tool calls. LLMResponse with content and/or tool calls.
""" """
@@ -201,33 +199,33 @@ class LiteLLMProvider(LLMProvider):
# Clamp max_tokens to at least 1 — negative or zero values cause # Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1". # LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens) max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model, "model": model,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)), "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
"max_tokens": max_tokens, "max_tokens": max_tokens,
"temperature": temperature, "temperature": temperature,
} }
# Apply model-specific overrides (e.g. kimi-k2.5 temperature) # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs) self._apply_model_overrides(model, kwargs)
# Pass api_key directly — more reliable than env vars alone # Pass api_key directly — more reliable than env vars alone
if self.api_key: if self.api_key:
kwargs["api_key"] = self.api_key kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints # Pass api_base for custom endpoints
if self.api_base: if self.api_base:
kwargs["api_base"] = self.api_base kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix) # Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers: if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers kwargs["extra_headers"] = self.extra_headers
if tools: if tools:
kwargs["tools"] = tools kwargs["tools"] = tools
kwargs["tool_choice"] = "auto" kwargs["tool_choice"] = "auto"
try: try:
response = await acompletion(**kwargs) response = await acompletion(**kwargs)
return self._parse_response(response) return self._parse_response(response)
@@ -237,12 +235,12 @@ class LiteLLMProvider(LLMProvider):
content=f"Error calling LLM: {str(e)}", content=f"Error calling LLM: {str(e)}",
finish_reason="error", finish_reason="error",
) )
def _parse_response(self, response: Any) -> LLMResponse: def _parse_response(self, response: Any) -> LLMResponse:
"""Parse LiteLLM response into our standard format.""" """Parse LiteLLM response into our standard format."""
choice = response.choices[0] choice = response.choices[0]
message = choice.message message = choice.message
tool_calls = [] tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls: if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls: for tc in message.tool_calls:
@@ -250,13 +248,13 @@ class LiteLLMProvider(LLMProvider):
args = tc.function.arguments args = tc.function.arguments
if isinstance(args, str): if isinstance(args, str):
args = json_repair.loads(args) args = json_repair.loads(args)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=_short_tool_id(), id=_short_tool_id(),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
)) ))
usage = {} usage = {}
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
usage = { usage = {
@@ -264,9 +262,9 @@ class LiteLLMProvider(LLMProvider):
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
} }
reasoning_content = getattr(message, "reasoning_content", None) or None reasoning_content = getattr(message, "reasoning_content", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=message.content,
tool_calls=tool_calls, tool_calls=tool_calls,
@@ -274,7 +272,7 @@ class LiteLLMProvider(LLMProvider):
usage=usage, usage=usage,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
) )
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model.""" """Get the default model."""
return self.default_model return self.default_model

View File

@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
import httpx import httpx
from loguru import logger from loguru import logger
from oauth_cli_kit import get_token as get_codex_token from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses" DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"

View File

@@ -2,7 +2,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any
import httpx import httpx
from loguru import logger from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider: class GroqTranscriptionProvider:
""" """
Voice transcription provider using Groq's Whisper API. Voice transcription provider using Groq's Whisper API.
Groq offers extremely fast transcription with a generous free tier. Groq offers extremely fast transcription with a generous free tier.
""" """
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions" self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
""" """
Transcribe an audio file using Groq. Transcribe an audio file using Groq.
Args: Args:
file_path: Path to the audio file. file_path: Path to the audio file.
Returns: Returns:
Transcribed text. Transcribed text.
""" """
if not self.api_key: if not self.api_key:
logger.warning("Groq API key not configured for transcription") logger.warning("Groq API key not configured for transcription")
return "" return ""
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
with open(path, "rb") as f: with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
} }
response = await client.post( response = await client.post(
self.api_url, self.api_url,
headers=headers, headers=headers,
files=files, files=files,
timeout=60.0 timeout=60.0
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
return data.get("text", "") return data.get("text", "")
except Exception as e: except Exception as e:
logger.error("Groq transcription error: {}", e) logger.error("Groq transcription error: {}", e)
return "" return ""

View File

@@ -1,5 +1,5 @@
"""Session management module.""" """Session management module."""
from nanobot.session.manager import SessionManager, Session from nanobot.session.manager import Session, SessionManager
__all__ = ["SessionManager", "Session"] __all__ = ["SessionManager", "Session"]

View File

@@ -2,9 +2,9 @@
import json import json
import shutil import shutil
from pathlib import Path
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -30,7 +30,7 @@ class Session:
updated_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now)
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files last_consolidated: int = 0 # Number of messages already consolidated to files
def add_message(self, role: str, content: str, **kwargs: Any) -> None: def add_message(self, role: str, content: str, **kwargs: Any) -> None:
"""Add a message to the session.""" """Add a message to the session."""
msg = { msg = {
@@ -41,7 +41,7 @@ class Session:
} }
self.messages.append(msg) self.messages.append(msg)
self.updated_at = datetime.now() self.updated_at = datetime.now()
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a user turn.""" """Return unconsolidated messages for LLM input, aligned to a user turn."""
unconsolidated = self.messages[self.last_consolidated:] unconsolidated = self.messages[self.last_consolidated:]
@@ -61,7 +61,7 @@ class Session:
entry[k] = m[k] entry[k] = m[k]
out.append(entry) out.append(entry)
return out return out
def clear(self) -> None: def clear(self) -> None:
"""Clear all messages and reset session to initial state.""" """Clear all messages and reset session to initial state."""
self.messages = [] self.messages = []
@@ -81,7 +81,7 @@ class SessionManager:
self.sessions_dir = ensure_dir(self.workspace / "sessions") self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
self._cache: dict[str, Session] = {} self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path: def _get_session_path(self, key: str) -> Path:
"""Get the file path for a session.""" """Get the file path for a session."""
safe_key = safe_filename(key.replace(":", "_")) safe_key = safe_filename(key.replace(":", "_"))
@@ -91,27 +91,27 @@ class SessionManager:
"""Legacy global session path (~/.nanobot/sessions/).""" """Legacy global session path (~/.nanobot/sessions/)."""
safe_key = safe_filename(key.replace(":", "_")) safe_key = safe_filename(key.replace(":", "_"))
return self.legacy_sessions_dir / f"{safe_key}.jsonl" return self.legacy_sessions_dir / f"{safe_key}.jsonl"
def get_or_create(self, key: str) -> Session: def get_or_create(self, key: str) -> Session:
""" """
Get an existing session or create a new one. Get an existing session or create a new one.
Args: Args:
key: Session key (usually channel:chat_id). key: Session key (usually channel:chat_id).
Returns: Returns:
The session. The session.
""" """
if key in self._cache: if key in self._cache:
return self._cache[key] return self._cache[key]
session = self._load(key) session = self._load(key)
if session is None: if session is None:
session = Session(key=key) session = Session(key=key)
self._cache[key] = session self._cache[key] = session
return session return session
def _load(self, key: str) -> Session | None: def _load(self, key: str) -> Session | None:
"""Load a session from disk.""" """Load a session from disk."""
path = self._get_session_path(key) path = self._get_session_path(key)
@@ -158,7 +158,7 @@ class SessionManager:
except Exception as e: except Exception as e:
logger.warning("Failed to load session {}: {}", key, e) logger.warning("Failed to load session {}: {}", key, e)
return None return None
def save(self, session: Session) -> None: def save(self, session: Session) -> None:
"""Save a session to disk.""" """Save a session to disk."""
path = self._get_session_path(session.key) path = self._get_session_path(session.key)
@@ -177,20 +177,20 @@ class SessionManager:
f.write(json.dumps(msg, ensure_ascii=False) + "\n") f.write(json.dumps(msg, ensure_ascii=False) + "\n")
self._cache[session.key] = session self._cache[session.key] = session
def invalidate(self, key: str) -> None: def invalidate(self, key: str) -> None:
"""Remove a session from the in-memory cache.""" """Remove a session from the in-memory cache."""
self._cache.pop(key, None) self._cache.pop(key, None)
def list_sessions(self) -> list[dict[str, Any]]: def list_sessions(self) -> list[dict[str, Any]]:
""" """
List all sessions. List all sessions.
Returns: Returns:
List of session info dicts. List of session info dicts.
""" """
sessions = [] sessions = []
for path in self.sessions_dir.glob("*.jsonl"): for path in self.sessions_dir.glob("*.jsonl"):
try: try:
# Read just the metadata line # Read just the metadata line
@@ -208,5 +208,5 @@ class SessionManager:
}) })
except Exception: except Exception:
continue continue
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True) return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)

View File

@@ -1,5 +1,5 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"] __all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]

View File

@@ -1,8 +1,8 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
import re import re
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
def ensure_dir(path: Path) -> Path: def ensure_dir(path: Path) -> Path: