Merge remote-tracking branch 'origin/main' into pr-1613

This commit is contained in:
Re-bin
2026-03-07 04:01:24 +00:00
12 changed files with 878 additions and 60 deletions

View File

@@ -420,6 +420,10 @@ nanobot channels login
nanobot gateway nanobot gateway
``` ```
> WhatsApp bridge updates are not applied automatically for existing installations.
> If you upgrade nanobot and need the latest WhatsApp bridge, run:
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
</details> </details>
<details> <details>
@@ -671,6 +675,7 @@ Config file: `~/.nanobot/config.json`
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | | `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |

View File

@@ -9,11 +9,17 @@ import makeWASocket, {
useMultiFileAuthState, useMultiFileAuthState,
fetchLatestBaileysVersion, fetchLatestBaileysVersion,
makeCacheableSignalKeyStore, makeCacheableSignalKeyStore,
downloadMediaMessage,
extractMessageContent as baileysExtractMessageContent,
} from '@whiskeysockets/baileys'; } from '@whiskeysockets/baileys';
import { Boom } from '@hapi/boom'; import { Boom } from '@hapi/boom';
import qrcode from 'qrcode-terminal'; import qrcode from 'qrcode-terminal';
import pino from 'pino'; import pino from 'pino';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { homedir } from 'os';
import { randomBytes } from 'crypto';
const VERSION = '0.1.0'; const VERSION = '0.1.0';
@@ -24,6 +30,7 @@ export interface InboundMessage {
content: string; content: string;
timestamp: number; timestamp: number;
isGroup: boolean; isGroup: boolean;
media?: string[];
} }
export interface WhatsAppClientOptions { export interface WhatsAppClientOptions {
@@ -110,14 +117,33 @@ export class WhatsAppClient {
if (type !== 'notify') return; if (type !== 'notify') return;
for (const msg of messages) { for (const msg of messages) {
// Skip own messages
if (msg.key.fromMe) continue; if (msg.key.fromMe) continue;
// Skip status updates
if (msg.key.remoteJid === 'status@broadcast') continue; if (msg.key.remoteJid === 'status@broadcast') continue;
const content = this.extractMessageContent(msg); const unwrapped = baileysExtractMessageContent(msg.message);
if (!content) continue; if (!unwrapped) continue;
const content = this.getTextContent(unwrapped);
let fallbackContent: string | null = null;
const mediaPaths: string[] = [];
if (unwrapped.imageMessage) {
fallbackContent = '[Image]';
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.documentMessage) {
fallbackContent = '[Document]';
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
unwrapped.documentMessage.fileName ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.videoMessage) {
fallbackContent = '[Video]';
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
}
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
if (!finalContent && mediaPaths.length === 0) continue;
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false; const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
@@ -125,18 +151,45 @@ export class WhatsAppClient {
id: msg.key.id || '', id: msg.key.id || '',
sender: msg.key.remoteJid || '', sender: msg.key.remoteJid || '',
pn: msg.key.remoteJidAlt || '', pn: msg.key.remoteJidAlt || '',
content, content: finalContent,
timestamp: msg.messageTimestamp as number, timestamp: msg.messageTimestamp as number,
isGroup, isGroup,
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
}); });
} }
}); });
} }
private extractMessageContent(msg: any): string | null { private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
const message = msg.message; try {
if (!message) return null; const mediaDir = join(homedir(), '.nanobot', 'media');
await mkdir(mediaDir, { recursive: true });
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
let outFilename: string;
if (fileName) {
// Documents have a filename — use it with a unique prefix to avoid collisions
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
outFilename = prefix + fileName;
} else {
const mime = mimetype || 'application/octet-stream';
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
}
const filepath = join(mediaDir, outFilename);
await writeFile(filepath, buffer);
return filepath;
} catch (err) {
console.error('Failed to download media:', err);
return null;
}
}
private getTextContent(message: any): string | null {
// Text message // Text message
if (message.conversation) { if (message.conversation) {
return message.conversation; return message.conversation;
@@ -147,19 +200,19 @@ export class WhatsAppClient {
return message.extendedTextMessage.text; return message.extendedTextMessage.text;
} }
// Image with caption // Image with optional caption
if (message.imageMessage?.caption) { if (message.imageMessage) {
return `[Image] ${message.imageMessage.caption}`; return message.imageMessage.caption || '';
} }
// Video with caption // Video with optional caption
if (message.videoMessage?.caption) { if (message.videoMessage) {
return `[Video] ${message.videoMessage.caption}`; return message.videoMessage.caption || '';
} }
// Document with caption // Document with optional caption
if (message.documentMessage?.caption) { if (message.documentMessage) {
return `[Document] ${message.documentMessage.caption}`; return message.documentMessage.caption || '';
} }
// Voice/Audio message // Voice/Audio message

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
import time
import unicodedata
from loguru import logger from loguru import logger
from telegram import BotCommand, ReplyParameters, Update from telegram import BotCommand, ReplyParameters, Update
@@ -19,6 +21,47 @@ from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
def _strip_md(s: str) -> str:
"""Strip markdown inline formatting from text."""
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
s = re.sub(r'__(.+?)__', r'\1', s)
s = re.sub(r'~~(.+?)~~', r'\1', s)
s = re.sub(r'`([^`]+)`', r'\1', s)
return s.strip()
def _render_table_box(table_lines: list[str]) -> str:
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
def dw(s: str) -> int:
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
rows: list[list[str]] = []
has_sep = False
for line in table_lines:
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
has_sep = True
continue
rows.append(cells)
if not rows or not has_sep:
return '\n'.join(table_lines)
ncols = max(len(r) for r in rows)
for r in rows:
r.extend([''] * (ncols - len(r)))
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
def dr(cells: list[str]) -> str:
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
out = [dr(rows[0])]
out.append(' '.join('' * w for w in widths))
for row in rows[1:]:
out.append(dr(row))
return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str: def _markdown_to_telegram_html(text: str) -> str:
""" """
Convert markdown to Telegram-safe HTML. Convert markdown to Telegram-safe HTML.
@@ -34,6 +77,27 @@ def _markdown_to_telegram_html(text: str) -> str:
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
lines = text.split('\n')
rebuilt: list[str] = []
li = 0
while li < len(lines):
if re.match(r'^\s*\|.+\|', lines[li]):
tbl: list[str] = []
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
tbl.append(lines[li])
li += 1
box = _render_table_box(tbl)
if box != '\n'.join(tbl):
code_blocks.append(box)
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
else:
rebuilt.extend(tbl)
else:
rebuilt.append(lines[li])
li += 1
text = '\n'.join(rebuilt)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
@@ -255,43 +319,49 @@ class TelegramChannel(BaseChannel):
# Send text content # Send text content
if msg.content and msg.content != "[empty message]": if msg.content and msg.content != "[empty message]":
is_progress = msg.metadata.get("_progress", False) is_progress = msg.metadata.get("_progress", False)
draft_id = msg.metadata.get("message_id")
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN): for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
try: # Final response: simulate streaming via draft, then persist
html = _markdown_to_telegram_html(chunk) if not is_progress:
if is_progress and draft_id: await self._send_with_streaming(chat_id, chunk, reply_params)
await self._app.bot.send_message_draft(
chat_id=chat_id,
draft_id=draft_id,
text=html,
parse_mode="HTML"
)
else: else:
await self._send_text(chat_id, chunk, reply_params)
async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None:
"""Send a plain text message with HTML fallback."""
try:
html = _markdown_to_telegram_html(text)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id, text=html, parse_mode="HTML",
text=html, reply_parameters=reply_params,
parse_mode="HTML",
reply_parameters=reply_params
) )
except Exception as e: except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
if is_progress and draft_id:
await self._app.bot.send_message_draft(
chat_id=chat_id,
draft_id=draft_id,
text=chunk
)
else:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id, text=text, reply_parameters=reply_params,
text=chunk,
reply_parameters=reply_params
) )
except Exception as e2: except Exception as e2:
logger.error("Error sending Telegram message: {}", e2) logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31)
try:
step = max(len(text) // 8, 40)
for i in range(step, len(text), step):
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text[:i],
)
await asyncio.sleep(0.04)
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text,
)
await asyncio.sleep(0.15)
except Exception:
pass
await self._send_text(chat_id, text, reply_params)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
import json import json
import mimetypes
from collections import OrderedDict from collections import OrderedDict
from loguru import logger from loguru import logger
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id) logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]" content = "[Voice Message: Transcription not available for WhatsApp yet]"
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
mime, _ = mimetypes.guess_type(p)
media_type = "image" if mime and mime.startswith("image/") else "file"
media_tag = f"[{media_type}: {p}]"
content = f"{content}\n{media_tag}" if content else media_tag
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender, # Use full LID for replies chat_id=sender, # Use full LID for replies
content=content, content=content,
media=media_paths,
metadata={ metadata={
"message_id": message_id, "message_id": message_id,
"timestamp": data.get("timestamp"), "timestamp": data.get("timestamp"),

View File

@@ -213,6 +213,7 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)
@@ -231,6 +232,20 @@ def _make_provider(config: Config):
default_model=model, default_model=model,
) )
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
return AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name) spec = find_by_name(provider_name)

View File

@@ -251,6 +251,7 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers.""" """Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig) anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig) openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig) openrouter: ProviderConfig = Field(default_factory=ProviderConfig)

View File

@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"] __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]

View File

@@ -0,0 +1,210 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg) result.append(msg)
return result return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,

View File

@@ -1,5 +1,6 @@
"""LiteLLM provider implementation for multi-provider support.""" """LiteLLM provider implementation for multi-provider support."""
import hashlib
import os import os
import secrets import secrets
import string import string
@@ -166,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
return _ANTHROPIC_EXTRA_KEYS return _ANTHROPIC_EXTRA_KEYS
return frozenset() return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod @staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key.""" """Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = [] sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
for msg in messages: id_map: dict[str, str] = {}
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls def map_id(value: Any) -> Any:
if clean.get("role") == "assistant" and "content" not in clean: if not isinstance(value, str):
clean["content"] = None return value
sanitized.append(clean) return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized return sanitized
async def chat( async def chat(

View File

@@ -79,6 +79,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
litellm_prefix="", litellm_prefix="",
is_direct=True, is_direct=True,
), ),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) ========= # === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback. # Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-" # OpenRouter: global gateway, keys start with "sk-or-"

View File

@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert result.finish_reason == "error"
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")