275 lines
8.8 KiB
Python
275 lines
8.8 KiB
Python
"""Codex OAuth login and token management."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import sys
|
|
import threading
|
|
import time
|
|
import urllib.parse
|
|
import webbrowser
|
|
from typing import Callable
|
|
|
|
import httpx
|
|
|
|
from nanobot.auth.codex.constants import (
|
|
AUTHORIZE_URL,
|
|
CLIENT_ID,
|
|
DEFAULT_ORIGINATOR,
|
|
REDIRECT_URI,
|
|
SCOPE,
|
|
TOKEN_URL,
|
|
)
|
|
from nanobot.auth.codex.models import CodexToken
|
|
from nanobot.auth.codex.pkce import (
|
|
_create_state,
|
|
_decode_account_id,
|
|
_generate_pkce,
|
|
_parse_authorization_input,
|
|
_parse_token_payload,
|
|
)
|
|
from nanobot.auth.codex.server import _start_local_server
|
|
from nanobot.auth.codex.storage import (
|
|
_FileLock,
|
|
_get_token_path,
|
|
_load_token_file,
|
|
_save_token_file,
|
|
_try_import_codex_cli_token,
|
|
)
|
|
|
|
|
|
async def _exchange_code_for_token_async(code: str, verifier: str) -> CodexToken:
    """Trade an OAuth authorization code for a Codex token (PKCE code grant).

    Args:
        code: Authorization code obtained from the authorize step.
        verifier: PKCE code verifier matching the challenge sent earlier.

    Returns:
        A ``CodexToken`` whose ``expires`` is an absolute epoch time in
        milliseconds and whose ``account_id`` is decoded from the access token.

    Raises:
        RuntimeError: If the token endpoint does not answer with HTTP 200.
    """
    form_body = {
        "grant_type": "authorization_code",
        "client_id": CLIENT_ID,
        "code": code,
        "code_verifier": verifier,
        "redirect_uri": REDIRECT_URI,
    }
    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}

    async with httpx.AsyncClient(timeout=30.0) as http:
        resp = await http.post(TOKEN_URL, data=form_body, headers=form_headers)

    # Non-streamed httpx responses are fully read, so the body is still
    # available after the client has been closed.
    if resp.status_code != 200:
        raise RuntimeError(f"Token exchange failed: {resp.status_code} {resp.text}")

    new_access, new_refresh, lifetime_s = _parse_token_payload(
        resp.json(), "Token response missing fields"
    )
    account = _decode_account_id(new_access)
    expiry_ms = int(time.time() * 1000 + lifetime_s * 1000)
    return CodexToken(
        access=new_access,
        refresh=new_refresh,
        expires=expiry_ms,
        account_id=account,
    )
|
|
|
|
|
|
def _refresh_token(refresh_token: str) -> CodexToken:
    """Obtain a new Codex token using a refresh token (``refresh_token`` grant).

    Args:
        refresh_token: Refresh token taken from the currently stored token.

    Returns:
        A new ``CodexToken``; ``expires`` is an absolute epoch time in
        milliseconds.

    Raises:
        RuntimeError: If the token endpoint does not answer with HTTP 200.
    """
    form_body = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CLIENT_ID,
    }

    with httpx.Client(timeout=30.0) as http:
        resp = http.post(
            TOKEN_URL,
            data=form_body,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

    if resp.status_code != 200:
        raise RuntimeError(f"Token refresh failed: {resp.status_code} {resp.text}")

    new_access, new_refresh, lifetime_s = _parse_token_payload(
        resp.json(), "Token refresh response missing fields"
    )
    account = _decode_account_id(new_access)
    expiry_ms = int(time.time() * 1000 + lifetime_s * 1000)
    return CodexToken(
        access=new_access,
        refresh=new_refresh,
        expires=expiry_ms,
        account_id=account,
    )
|
|
|
|
|
|
def get_codex_token() -> CodexToken:
    """Get an available token (refresh if needed).

    The token is loaded from the token file, falling back to importing the
    Codex CLI's credentials. If more than 60 seconds of validity remain it is
    returned as-is. Otherwise the refresh happens under a file lock, and the
    token file is re-read first in case another process already refreshed it.
    A successfully refreshed token is saved before being returned.

    Returns:
        A ``CodexToken``.

    Raises:
        RuntimeError: If no credentials can be found.
        Exception: Whatever refreshing/saving raised, unless the token file
            holds a token that has not yet expired (that token is returned).
    """
    refresh_margin_ms = 60 * 1000  # refresh this long before actual expiry

    token = _load_token_file() or _try_import_codex_cli_token()
    if not token:
        raise RuntimeError("Codex OAuth credentials not found. Please run the login command.")

    if token.expires - int(time.time() * 1000) > refresh_margin_ms:
        return token

    lock_path = _get_token_path().with_suffix(".lock")
    with _FileLock(lock_path):
        # Re-read to avoid a stale token if another process refreshed it
        # while we were waiting for the lock.
        token = _load_token_file() or token
        if token.expires - int(time.time() * 1000) > refresh_margin_ms:
            return token
        try:
            refreshed = _refresh_token(token.refresh)
            _save_token_file(refreshed)
            return refreshed
        except Exception:
            # If refresh fails, re-read the file to avoid false negatives.
            # Sample the clock again: the refresh is a network round-trip, so a
            # timestamp taken before it could let an expired token through.
            latest = _load_token_file()
            if latest and latest.expires - int(time.time() * 1000) > 0:
                return latest
            raise
|
|
|
|
|
|
def _readline_on_daemon_thread(loop: asyncio.AbstractEventLoop) -> asyncio.Future[str]:
    """Read one stdin line on a daemon thread and return a future for it.

    Unlike ``loop.run_in_executor(None, ...)``, an abandoned read (the awaiting
    task was cancelled) does not make ``asyncio.run()`` wait at shutdown for
    the default executor, because the daemon thread is not part of it.
    """
    future: asyncio.Future[str] = loop.create_future()

    def _set_result(line: str) -> None:
        if not future.done():  # may have been cancelled meanwhile
            future.set_result(line)

    def _set_exception(exc: BaseException) -> None:
        if not future.done():
            future.set_exception(exc)

    def _worker() -> None:
        try:
            try:
                line = sys.stdin.readline()
            except Exception as exc:
                loop.call_soon_threadsafe(_set_exception, exc)
            else:
                loop.call_soon_threadsafe(_set_result, line)
        except RuntimeError:
            # The event loop was closed before the read finished; nobody is
            # waiting for this line any more.
            pass

    threading.Thread(target=_worker, name="codex-stdin-reader", daemon=True).start()
    return future


async def _read_stdin_line() -> str:
    """Read one line from stdin without blocking the event loop.

    Uses ``loop.add_reader`` when the loop supports it and stdin can be
    registered. Otherwise (no ``add_reader``, or registration fails) the line
    is read on a daemon thread.

    Returns:
        The line exactly as ``sys.stdin.readline()`` returned it ("" at EOF).

    If the awaiting task is cancelled, the reader is unregistered (reader
    path), or the blocked read is abandoned without stalling loop shutdown
    (thread path). An abandoned thread read may still consume one line later.
    """
    loop = asyncio.get_running_loop()

    if hasattr(loop, "add_reader") and sys.stdin:
        future: asyncio.Future[str] = loop.create_future()

        def _on_readable() -> None:
            line = sys.stdin.readline()
            if not future.done():
                future.set_result(line)

        try:
            loop.add_reader(sys.stdin, _on_readable)
        except Exception:
            # e.g. loops without selector support (NotImplementedError) or a
            # stdin object that has no usable file descriptor.
            return await _readline_on_daemon_thread(loop)

        try:
            return await future
        finally:
            try:
                loop.remove_reader(sys.stdin)
            except Exception:
                pass

    return await _readline_on_daemon_thread(loop)
|
|
|
|
|
|
async def _await_manual_input(print_fn: Callable[[str], None]) -> str:
    """Ask the user to paste a code/redirect URL, then wait for one stdin line.

    Args:
        print_fn: Sink used to show the (markup-tagged) instruction message.

    Returns:
        The raw line read from stdin.
    """
    instructions = (
        "[cyan]Paste the authorization code (or full redirect URL), "
        "or wait for the browser callback:[/cyan]"
    )
    print_fn(instructions)
    line = await _read_stdin_line()
    return line
|
|
|
|
|
|
def login_codex_oauth_interactive(
    print_fn: Callable[[str], None],
    prompt_fn: Callable[[str], str],
    originator: str = DEFAULT_ORIGINATOR,
) -> CodexToken:
    """Interactive login flow.

    Builds a PKCE authorization URL and opens it in a browser. It then waits
    for the authorization code from whichever source delivers first:

    - the local callback server, for up to 120 seconds;
    - a line pasted on stdin.

    If neither yields a code, it falls back to ``prompt_fn``. The code is
    exchanged for tokens, and the token is saved to the token file.

    Args:
        print_fn: Sink for user-facing messages (they contain ``[color]``
            markup tags).
        prompt_fn: Blocking prompt used as the last-resort input method; it
            is run in an executor so it does not block the event loop.
        originator: Value sent as the ``originator`` authorization parameter.

    Returns:
        The newly obtained ``CodexToken``.

    Raises:
        RuntimeError: On state mismatch, when no authorization code is found,
            or when the token exchange fails.

    When called from inside a running event loop, the flow runs on a separate
    thread with its own loop, and this call blocks until it finishes.
    """

    async def _login_async() -> CodexToken:
        verifier, challenge = _generate_pkce()
        state = _create_state()

        params = {
            "response_type": "code",
            "client_id": CLIENT_ID,
            "redirect_uri": REDIRECT_URI,
            "scope": SCOPE,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
            "id_token_add_organizations": "true",
            "codex_cli_simplified_flow": "true",
            "originator": originator,
        }
        url = f"{AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"

        loop = asyncio.get_running_loop()
        code_future: asyncio.Future[str] = loop.create_future()

        def _deliver_code(code_value: str) -> None:
            # Runs on the event-loop thread, so this done() check cannot
            # interleave with wait_for() cancelling the future (timeout, or
            # manual input winning) and set_result never hits a settled future.
            if not code_future.done():
                code_future.set_result(code_value)

        def _notify(code_value: str) -> None:
            # May be invoked off the loop thread (presumably by the callback
            # server's handler), hence call_soon_threadsafe.
            try:
                loop.call_soon_threadsafe(_deliver_code, code_value)
            except RuntimeError:
                # Event loop already closed: this login attempt is over.
                pass

        server, server_error = _start_local_server(state, on_code=_notify)
        print_fn("[cyan]A browser window will open for login. If it doesn't, open this URL manually:[/cyan]")
        print_fn(url)
        try:
            webbrowser.open(url)
        except Exception:
            # Best effort: the URL has already been printed for manual use.
            pass

        if not server and server_error:
            print_fn(
                "[yellow]"
                f"Local callback server could not start ({server_error}). "
                "You will need to paste the callback URL or authorization code."
                "[/yellow]"
            )

        code: str | None = None
        try:
            if server:
                print_fn("[dim]Waiting for browser callback...[/dim]")

                # Race the browser callback (bounded to 120 s) against a line
                # pasted on stdin; take whichever finishes first.
                tasks: list[asyncio.Task[object]] = []
                callback_task = asyncio.create_task(asyncio.wait_for(code_future, timeout=120))
                tasks.append(callback_task)
                manual_task = asyncio.create_task(_await_manual_input(print_fn))
                tasks.append(manual_task)

                done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                for task in pending:
                    task.cancel()

                for task in done:
                    try:
                        result = task.result()
                    except asyncio.TimeoutError:
                        result = None
                    if not result:
                        continue
                    if task is manual_task:
                        # Pasted input may be a bare code or a full redirect URL;
                        # only validate state when one was actually supplied.
                        parsed_code, parsed_state = _parse_authorization_input(result)
                        if parsed_state and parsed_state != state:
                            raise RuntimeError("State validation failed.")
                        code = parsed_code
                    else:
                        code = result
                    if code:
                        break

            if not code:
                prompt = "Please paste the callback URL or authorization code:"
                raw = await loop.run_in_executor(None, prompt_fn, prompt)
                parsed_code, parsed_state = _parse_authorization_input(raw)
                if parsed_state and parsed_state != state:
                    raise RuntimeError("State validation failed.")
                code = parsed_code

            if not code:
                raise RuntimeError("Authorization code not found.")

            print_fn("[dim]Exchanging authorization code for tokens...[/dim]")
            token = await _exchange_code_for_token_async(code, verifier)
            _save_token_file(token)
            return token
        finally:
            if server:
                server.shutdown()
                server.server_close()

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: drive the flow directly.
        return asyncio.run(_login_async())

    # A loop is already running here, so asyncio.run() is not allowed; run
    # the flow on a helper thread with its own loop and wait for it.
    result: list[CodexToken] = []
    error: list[Exception] = []

    def _runner() -> None:
        try:
            result.append(asyncio.run(_login_async()))
        except Exception as exc:
            error.append(exc)

    thread = threading.Thread(target=_runner)
    thread.start()
    thread.join()
    if error:
        raise error[0]
    return result[0]
|