Files
nanobot/nanobot/auth/codex/server.py
2026-02-05 17:53:00 +08:00

116 lines
3.7 KiB
Python

"""Local OAuth callback server."""
from __future__ import annotations
import socket
import threading
import urllib.parse
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Callable
from nanobot.auth.codex.constants import SUCCESS_HTML
class _OAuthHandler(BaseHTTPRequestHandler):
"""Local callback HTTP handler."""
server_version = "NanobotOAuth/1.0"
protocol_version = "HTTP/1.1"
def do_GET(self) -> None: # noqa: N802
try:
url = urllib.parse.urlparse(self.path)
if url.path != "/auth/callback":
self.send_response(404)
self.end_headers()
self.wfile.write(b"Not found")
return
qs = urllib.parse.parse_qs(url.query)
code = qs.get("code", [None])[0]
state = qs.get("state", [None])[0]
if state != self.server.expected_state:
self.send_response(400)
self.end_headers()
self.wfile.write(b"State mismatch")
return
if not code:
self.send_response(400)
self.end_headers()
self.wfile.write(b"Missing code")
return
self.server.code = code
try:
if getattr(self.server, "on_code", None):
self.server.on_code(code)
except Exception:
pass
body = SUCCESS_HTML.encode("utf-8")
self.send_response(200)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(body)))
self.send_header("Connection", "close")
self.end_headers()
self.wfile.write(body)
try:
self.wfile.flush()
except Exception:
pass
self.close_connection = True
except Exception:
self.send_response(500)
self.end_headers()
self.wfile.write(b"Internal error")
def log_message(self, format: str, *args: Any) -> None: # noqa: A003
# Suppress default logs to avoid noisy output.
return
class _OAuthServer(HTTPServer):
"""OAuth callback server with state."""
def __init__(
self,
server_address: tuple[str, int],
expected_state: str,
on_code: Callable[[str], None] | None = None,
):
super().__init__(server_address, _OAuthHandler)
self.expected_state = expected_state
self.code: str | None = None
self.on_code = on_code
def _start_local_server(
state: str,
on_code: Callable[[str], None] | None = None,
) -> tuple[_OAuthServer | None, str | None]:
"""Start a local OAuth callback server on the first available localhost address."""
try:
addrinfos = socket.getaddrinfo("localhost", 1455, type=socket.SOCK_STREAM)
except OSError as exc:
return None, f"Failed to resolve localhost: {exc}"
last_error: OSError | None = None
for family, _socktype, _proto, _canonname, sockaddr in addrinfos:
try:
# Support IPv4/IPv6 to avoid missing callbacks when localhost resolves to ::1.
class _AddrOAuthServer(_OAuthServer):
address_family = family
server = _AddrOAuthServer(sockaddr, state, on_code=on_code)
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
return server, None
except OSError as exc:
last_error = exc
continue
if last_error:
return None, f"Local callback server failed to start: {last_error}"
return None, "Local callback server failed to start: unknown error"