Files
CaptchBreaker/server.py

823 lines
27 KiB
Python

"""
FastAPI HTTP 推理服务 (纯推理,不依赖 torch/训练代码)
仅依赖: fastapi, uvicorn, python-multipart, onnxruntime, pillow, numpy
API:
POST /solve JSON base64 图片识别 (同步)
POST /solve/upload multipart 文件上传识别 (同步)
POST /createTask 创建异步识别任务
POST /getTaskResult 查询异步任务结果
POST /getBalance 查询本地服务余额占位值
POST /api/v1/* 兼容别名
GET /health 健康检查
GET /api/v1/health 健康检查兼容别名
启动:
uv sync --extra server
uv run captcha serve --port 8080
"""
from __future__ import annotations

import base64
import binascii
import hashlib
import hmac
import json
import logging
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from http.client import HTTPException
from pathlib import Path
from typing import Callable
from urllib.error import URLError
from urllib.parse import urlencode
from urllib.request import Request as UrlRequest, urlopen

from config import CAPTCHA_TYPES, FUN_CAPTCHA_TASKS, SERVER_CONFIG
logger = logging.getLogger(__name__)

# Task ``type`` values accepted by ``/createTask``. The type is only checked
# against this set and echoed back in task metadata; which model actually runs
# is decided by the task's ``question`` / ``captchaType`` fields.
ASYNC_TASK_TYPES = {
    # Image-to-text style task names.
    "CaptchaImageTask",
    "ImageToTextTask",
    "ImageToTextTaskM1",
    "ImageToTextTaskMuggle",
    "ImageToTextTaskProxyless",
} | {
    # FunCaptcha style task names.
    "FunCaptcha",
    "FunCaptchaTask",
    "FunCaptchaTaskProxyless",
}
class ServiceError(Exception):
    """Expected, client-reportable service error.

    Attributes:
        code: Machine-readable error code (e.g. ``ERROR_EMPTY_IMAGE``).
        description: Human-readable message; also the exception's ``str()``.
        status_code: HTTP status to use when the error is returned directly.
    """

    def __init__(self, code: str, description: str, status_code: int = 400):
        self.code, self.description, self.status_code = code, description, status_code
        super().__init__(description)
@dataclass
class TaskRecord:
    """State of one asynchronous recognition task.

    Persisted via ``asdict`` to a JSON file and rebuilt with
    ``TaskRecord(**data)`` on load, so the field names are the on-disk format.
    """

    # Identity and lifetime; timestamps are Unix time in whole seconds.
    task_id: str
    created_at: int
    expires_at: int
    # Request metadata, echoed back in getTaskResult responses.
    client_ip: str | None = None
    task_type: str | None = None
    captcha_type: str | None = None
    question: str | None = None
    # Completion-callback target and delivery bookkeeping.
    callback_url: str | None = None
    callback_attempts: int = 0
    callback_delivered_at: int | None = None
    callback_last_error: str | None = None
    # Lifecycle: "processing", then "ready" (result set) or "failed" (error_* set).
    status: str = "processing"
    result: dict | None = None
    error_code: str | None = None
    error_description: str | None = None
    completed_at: int | None = None
class TaskManager:
    """In-memory async task manager backed by per-task JSON files.

    Suited to single-process deployments: the task table lives in this
    process's memory behind one lock. Every state change is also written to
    ``<tasks_dir>/<task_id>.json`` so records survive a restart until they
    expire. Image bytes are never persisted, so a task that was still
    ``processing`` when the previous process stopped cannot be resumed; such
    tasks are marked ``failed`` when reloaded.
    """

    def __init__(
        self,
        solve_fn: Callable[[bytes, str | None, str | None], dict],
        ttl_seconds: int,
        max_workers: int,
        tasks_dir: str | Path,
    ):
        """Create the manager and reload unexpired tasks from ``tasks_dir``.

        Args:
            solve_fn: Called as ``solve_fn(image_bytes, captcha_type, question)``
                on a worker thread; its return value becomes the task result.
                A raised ``ServiceError`` marks the task failed with its code.
            ttl_seconds: Lifetime of each task record, counted from creation.
            max_workers: Size of the solving thread pool.
            tasks_dir: Directory holding persisted task files; created if missing.
        """
        self._solve_fn = solve_fn
        self._ttl_seconds = ttl_seconds
        self._tasks_dir = Path(tasks_dir)
        self._tasks_dir.mkdir(parents=True, exist_ok=True)
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="captcha-task",
        )
        self._tasks: dict[str, TaskRecord] = {}
        self._lock = threading.Lock()
        self._load_tasks()

    def create_task(
        self,
        image_bytes: bytes,
        captcha_type: str | None,
        question: str | None = None,
        client_ip: str | None = None,
        task_type: str | None = None,
        callback_url: str | None = None,
    ) -> str:
        """Register a task, persist it and queue it for solving.

        Returns:
            The new task id (32 hex characters from ``uuid4``).
        """
        now = int(time.time())
        task_id = uuid.uuid4().hex
        record = TaskRecord(
            task_id=task_id,
            created_at=now,
            expires_at=now + self._ttl_seconds,
            client_ip=client_ip,
            task_type=task_type,
            captcha_type=captcha_type,
            question=question,
            callback_url=callback_url,
        )
        with self._lock:
            self._cleanup_locked(now)
            self._tasks[task_id] = record
            self._persist_task_locked(record)
        self._executor.submit(self._run_task, task_id, image_bytes, captcha_type, question)
        return task_id

    def get_task(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or None if unknown or expired.

        The returned object is the shared record that worker threads update.
        """
        with self._lock:
            self._cleanup_locked()
            return self._tasks.get(task_id)

    def stats(self) -> dict:
        """Return counts of live tasks per status plus the configured TTL."""
        with self._lock:
            self._cleanup_locked()
            processing = sum(1 for task in self._tasks.values() if task.status == "processing")
            ready = sum(1 for task in self._tasks.values() if task.status == "ready")
            failed = sum(1 for task in self._tasks.values() if task.status == "failed")
            return {
                "active": len(self._tasks),
                "processing": processing,
                "ready": ready,
                "failed": failed,
                "ttl_seconds": self._ttl_seconds,
            }

    def shutdown(self):
        """Stop accepting work and wait for already-queued tasks to finish."""
        self._executor.shutdown(wait=True, cancel_futures=False)

    def _run_task(
        self,
        task_id: str,
        image_bytes: bytes,
        captcha_type: str | None,
        question: str | None,
    ):
        """Worker-thread body: solve, record the outcome, then send the callback."""
        try:
            result = self._solve_fn(image_bytes, captcha_type, question)
        except ServiceError as exc:
            task = self._mark_failed(task_id, exc.code, exc.description)
        except Exception as exc:  # pragma: no cover - defensive fallback
            task = self._mark_failed(task_id, "ERROR_TASK_FAILED", str(exc))
        else:
            task = self._mark_ready(task_id, result)
        if task and task.callback_url:
            try:
                self._send_task_callback(task)
            except Exception:  # pragma: no cover - defensive fallback
                # An exception escaping an executor job is stored on a Future
                # that nobody reads, i.e. silently lost; log it instead.
                logger.exception("task callback crashed: task_id=%s", task_id)

    def _mark_ready(self, task_id: str, result: dict) -> TaskRecord | None:
        """Store ``result`` and mark the task ready; None if it already expired."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return None
            record.status = "ready"
            record.result = result
            record.completed_at = int(time.time())
            self._persist_task_locked(record)
            return record

    def _mark_failed(self, task_id: str, code: str, description: str) -> TaskRecord | None:
        """Store the error and mark the task failed; None if it already expired."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return None
            record.status = "failed"
            record.error_code = code
            record.error_description = description
            record.completed_at = int(time.time())
            self._persist_task_locked(record)
            return record

    def _cleanup_locked(self, now: int | None = None):
        """Drop expired tasks and their files. Caller must hold ``self._lock``."""
        if now is None:
            now = int(time.time())
        expired_ids = [
            task_id
            for task_id, task in self._tasks.items()
            if task.expires_at <= now
        ]
        for task_id in expired_ids:
            self._tasks.pop(task_id, None)
            self._delete_task_file(task_id)

    def _update_callback_state(
        self,
        task_id: str,
        *,
        attempts: int | None = None,
        delivered_at: int | None = None,
        last_error: str | None = None,
    ):
        """Record callback bookkeeping for ``task_id`` and persist it.

        ``attempts`` and ``delivered_at`` are only written when given, but
        ``last_error`` is always overwritten (so passing None clears it).
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return
            if attempts is not None:
                record.callback_attempts = attempts
            if delivered_at is not None:
                record.callback_delivered_at = delivered_at
            record.callback_last_error = last_error
            self._persist_task_locked(record)

    def _task_path(self, task_id: str) -> Path:
        """Path of the JSON file persisting ``task_id``."""
        return self._tasks_dir / f"{task_id}.json"

    def _persist_task_locked(self, record: TaskRecord):
        """Write ``record`` via a temp file + rename so readers never see a partial file."""
        path = self._task_path(record.task_id)
        tmp_path = path.with_suffix(".json.tmp")
        tmp_path.write_text(
            json.dumps(asdict(record), ensure_ascii=False, sort_keys=True),
            encoding="utf-8",
        )
        tmp_path.replace(path)

    def _delete_task_file(self, task_id: str):
        """Remove the persisted file for ``task_id`` if it exists."""
        self._task_path(task_id).unlink(missing_ok=True)

    def _load_tasks(self):
        """Reload persisted tasks, discarding unreadable and expired files.

        Tasks still marked ``processing`` are failed: the worker that owned
        them belonged to a previous process and their image bytes were never
        persisted, so without this they would report ``processing`` until
        their TTL ran out.
        """
        now = int(time.time())
        for path in sorted(self._tasks_dir.glob("*.json")):
            try:
                task_data = json.loads(path.read_text(encoding="utf-8"))
                task = TaskRecord(**task_data)
            except Exception as exc:  # pragma: no cover - defensive fallback
                logger.warning("skip invalid task file: path=%s err=%s", path, exc)
                path.unlink(missing_ok=True)
                continue
            if task.expires_at <= now:
                path.unlink(missing_ok=True)
                continue
            if task.status == "processing":
                task.status = "failed"
                task.error_code = "ERROR_TASK_FAILED"
                task.error_description = "服务重启,任务已中断"
                task.completed_at = now
                # No other thread can see the manager yet, so no lock needed.
                self._persist_task_locked(task)
            self._tasks[task.task_id] = task

    def _send_task_callback(self, task: TaskRecord):
        """POST the task outcome to ``task.callback_url`` with retry/backoff.

        Retries up to ``callback_max_retries`` extra times, sleeping
        ``callback_retry_delay_seconds`` multiplied by
        ``callback_retry_backoff`` after each failure. Every attempt is
        recorded via ``_update_callback_state``.
        """
        max_retries = max(0, int(SERVER_CONFIG.get("callback_max_retries", 0)))
        delay = max(0.0, float(SERVER_CONFIG.get("callback_retry_delay_seconds", 0.0)))
        backoff = max(1.0, float(SERVER_CONFIG.get("callback_retry_backoff", 1.0)))
        payload = _build_task_callback_payload(task)
        for attempt in range(max_retries + 1):
            attempt_no = attempt + 1
            try:
                _post_callback(task.callback_url, payload)
                self._update_callback_state(
                    task.task_id,
                    attempts=attempt_no,
                    delivered_at=int(time.time()),
                    last_error=None,
                )
                return
            # HTTPException covers http.client protocol errors such as
            # BadStatusLine / IncompleteRead, which are not OSError subclasses.
            except (URLError, OSError, ValueError, HTTPException) as exc:
                self._update_callback_state(
                    task.task_id,
                    attempts=attempt_no,
                    last_error=str(exc),
                )
                if attempt >= max_retries:
                    logger.warning(
                        "task callback failed: task_id=%s url=%s attempts=%s err=%s",
                        task.task_id,
                        task.callback_url,
                        attempt_no,
                        exc,
                    )
                    return
                logger.warning(
                    "task callback retry: task_id=%s url=%s attempt=%s err=%s",
                    task.task_id,
                    task.callback_url,
                    attempt_no,
                    exc,
                )
                if delay > 0:
                    time.sleep(delay)
                    delay *= backoff
def _task_success_payload(**extra) -> dict:
payload = {"errorId": 0}
payload.update(extra)
return payload
def _task_error_payload(code: str, description: str, **extra) -> dict:
payload = {
"errorId": 1,
"errorCode": code,
"errorDescription": description,
}
payload.update(extra)
return payload
def _normalize_captcha_type(captcha_type: str | None) -> str | None:
if captcha_type in (None, ""):
return None
if captcha_type not in CAPTCHA_TYPES:
raise ServiceError(
"ERROR_BAD_CAPTCHA_TYPE",
f"不支持的验证码类型: {captcha_type}",
status_code=400,
)
return captcha_type
def _normalize_question(question: str | None) -> str | None:
if question in (None, ""):
return None
if question not in FUN_CAPTCHA_TASKS:
raise ServiceError(
"ERROR_TASK_QUESTION_NOT_SUPPORTED",
f"不支持的专项任务 question: {question}",
status_code=400,
)
return question
def _decode_image_b64(encoded: str) -> bytes:
if not encoded:
raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)
if encoded.startswith("data:") and "," in encoded:
encoded = encoded.split(",", 1)[1]
encoded = "".join(encoded.split())
encoded += "=" * (-len(encoded) % 4)
try:
return base64.b64decode(encoded, altchars=b"-_", validate=True)
except (binascii.Error, ValueError) as exc:
raise ServiceError(
"ERROR_BAD_IMAGE_BASE64",
"base64 解码失败",
status_code=400,
) from exc
def _validate_client_key(client_key: str | None):
    """Check ``client_key`` against ``SERVER_CONFIG["client_key"]``.

    Authentication is disabled when no key is configured (missing or empty).

    Raises:
        ServiceError: ``ERROR_KEY_DOES_NOT_EXIST`` (HTTP 401) when a key is
            configured and ``client_key`` does not match it.
    """
    expected = SERVER_CONFIG.get("client_key")
    if not expected:
        return
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key was correct. Compare UTF-8 bytes because compare_digest
    # raises TypeError for non-ASCII ``str`` arguments.
    supplied = client_key.encode("utf-8") if isinstance(client_key, str) else b""
    if not hmac.compare_digest(supplied, str(expected).encode("utf-8")):
        raise ServiceError(
            "ERROR_KEY_DOES_NOT_EXIST",
            "clientKey 无效",
            status_code=401,
        )
def _build_task_solution(task: TaskRecord) -> dict:
if task.result is None:
return {}
if "objects" in task.result:
objects = list(task.result["objects"])
primary = objects[0] if objects else None
solution = {
"objects": objects,
"answer": primary,
"raw": task.result.get("raw", "" if primary is None else str(primary)),
"timeMs": task.result["time_ms"],
}
if task.question:
solution["question"] = task.question
if primary is not None:
solution["text"] = str(primary)
return solution
return {
"text": task.result["result"],
"answer": task.result["result"],
"raw": task.result["raw"],
"captchaType": task.result["type"],
"timeMs": task.result["time_ms"],
}
def _build_task_meta(task: TaskRecord) -> dict:
payload = {
"type": task.task_type,
"captchaType": task.captcha_type,
}
if task.question is not None:
payload["question"] = task.question
return payload
def _build_task_result_payload(task: TaskRecord) -> dict:
    """Render a task record as a ``getTaskResult`` response body.

    ``processing`` tasks report timing, metadata and callback state. ``failed``
    tasks use the error envelope, with the stored code/description or generic
    fallbacks. ``ready`` tasks include the solution and the configured cost.
    """

    def callback_info(delivered: bool) -> dict:
        # Callback configuration and delivery state, shared by all statuses.
        return {
            "configured": bool(task.callback_url),
            "url": task.callback_url,
            "attempts": task.callback_attempts,
            "delivered": delivered,
            "deliveredAt": task.callback_delivered_at,
            "lastError": task.callback_last_error,
        }

    meta = _build_task_meta(task)
    if task.status == "processing":
        return _task_success_payload(
            taskId=task.task_id,
            status="processing",
            createTime=task.created_at,
            expiresAt=task.expires_at,
            task=meta,
            callback=callback_info(False),
        )

    was_delivered = task.callback_delivered_at is not None
    if task.status == "failed":
        return _task_error_payload(
            task.error_code or "ERROR_TASK_FAILED",
            task.error_description or "任务执行失败",
            taskId=task.task_id,
            status="failed",
            ip=task.client_ip,
            createTime=task.created_at,
            endTime=task.completed_at,
            expiresAt=task.expires_at,
            task=meta,
            callback=callback_info(was_delivered),
        )

    return _task_success_payload(
        taskId=task.task_id,
        status="ready",
        solution=_build_task_solution(task),
        cost=f"{float(SERVER_CONFIG['task_cost']):.5f}",
        ip=task.client_ip,
        createTime=task.created_at,
        endTime=task.completed_at,
        expiresAt=task.expires_at,
        solveCount=1,
        task=meta,
        callback=callback_info(was_delivered),
    )
def _build_task_callback_payload(task: TaskRecord) -> dict[str, str]:
payload = {
"id": task.task_id,
"taskId": task.task_id,
"status": task.status,
"errorId": "0" if task.status == "ready" else "1",
}
if task.status == "ready":
solution = _build_task_solution(task)
payload["code"] = str(solution.get("answer", ""))
payload["answer"] = str(solution.get("answer", ""))
payload["raw"] = str(solution.get("raw", ""))
payload["timeMs"] = str(solution["timeMs"])
payload["cost"] = f"{float(SERVER_CONFIG['task_cost']):.5f}"
if "text" in solution:
payload["text"] = str(solution["text"])
if "captchaType" in solution:
payload["captchaType"] = str(solution["captchaType"])
if "objects" in solution:
payload["objects"] = json.dumps(solution["objects"], ensure_ascii=False)
if "question" in solution:
payload["question"] = str(solution["question"])
else:
payload.update(
{
"errorCode": task.error_code or "ERROR_TASK_FAILED",
"errorDescription": task.error_description or "任务执行失败",
}
)
return payload
def _sign_callback_payload(data: bytes, timestamp: str, secret: str) -> str:
message = timestamp.encode("utf-8") + b"." + data
return hmac.new(secret.encode("utf-8"), message, hashlib.sha256).hexdigest()
def _build_callback_request(callback_url: str, payload: dict[str, str]) -> UrlRequest:
    """Build the form-encoded POST request for a task callback.

    When ``SERVER_CONFIG["callback_signing_secret"]`` is set, timestamp,
    algorithm and HMAC signature headers are added (see
    ``_sign_callback_payload``).
    """
    body = urlencode(payload).encode("utf-8")
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    signing_secret = SERVER_CONFIG.get("callback_signing_secret")
    if signing_secret:
        ts = str(int(time.time()))
        headers.update(
            {
                "X-CaptchaBreaker-Timestamp": ts,
                "X-CaptchaBreaker-Signature-Alg": "hmac-sha256",
                "X-CaptchaBreaker-Signature": _sign_callback_payload(body, ts, signing_secret),
            }
        )
    return UrlRequest(callback_url, data=body, headers=headers, method="POST")
def _post_callback(callback_url: str, payload: dict[str, str]):
    """Send one callback POST, using ``SERVER_CONFIG["callback_timeout_seconds"]``.

    Errors from ``urlopen`` propagate to the caller, which decides on retries.
    """
    req = _build_callback_request(callback_url, payload)
    with urlopen(req, timeout=SERVER_CONFIG["callback_timeout_seconds"]) as resp:
        # The response body is read and discarded; its content is not used.
        resp.read()
def create_app(pipeline_factory=None, funcaptcha_factories=None):
    """
    Create the FastAPI application instance (factory function).

    ``cli.py``'s ``cmd_serve`` depends on this signature.

    Args:
        pipeline_factory: Zero-argument callable building the generic captcha
            pipeline. If omitted, ``CaptchaPipeline`` is built in the lifespan
            startup hook and a ``FileNotFoundError`` leaves it unloaded; if
            given, it is called immediately and its errors propagate.
        funcaptcha_factories: ``{question: zero-arg factory}`` for FunCaptcha
            pipelines, with the same omitted/given semantics.

    Returns:
        The FastAPI app. Its ``TaskManager`` is exposed as
        ``app.state.task_manager`` and is shut down when the lifespan ends.

    NOTE(review): this module uses ``from __future__ import annotations``, so
    endpoint annotations such as ``SolveRequest`` and ``Request`` are strings
    naming objects local to this function. Some FastAPI versions resolve such
    strings only against module globals; confirm the installed version copes.
    """
    import asyncio
    from contextlib import asynccontextmanager
    from typing import Optional
    from urllib.parse import urlsplit

    from fastapi import Body, FastAPI, File, Query, Request, UploadFile
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel

    from inference.fun_captcha import FunCaptchaRollballPipeline
    from inference.pipeline import CaptchaPipeline

    pipeline: Optional[CaptchaPipeline] = None
    # Default pipelines load lazily in ``lifespan``; injected factories load now.
    load_pipeline_on_startup = pipeline_factory is None
    pipeline_factory = pipeline_factory or CaptchaPipeline

    fun_captcha_pipelines: dict[str, object] = {}
    load_fun_captcha_on_startup = funcaptcha_factories is None
    if funcaptcha_factories is None:
        funcaptcha_factories = {
            "4_3d_rollball_animals": lambda: FunCaptchaRollballPipeline(
                question="4_3d_rollball_animals"
            ),
        }

    def _infer(image_bytes: bytes, captcha_type: str | None, question: str | None) -> dict:
        """Run one recognition synchronously and return a normalized result dict.

        A ``question`` routes to the matching FunCaptcha pipeline (and
        ``captcha_type`` is then ignored); otherwise the generic pipeline runs.
        Every failure surfaces as ``ServiceError``.
        """
        normalized_question = _normalize_question(question)
        normalized_type = None if normalized_question is not None else _normalize_captcha_type(captcha_type)
        if pipeline is None:
            if normalized_question is None:
                raise ServiceError(
                    "ERROR_NO_MODELS_LOADED",
                    "模型未加载,请先训练并导出 ONNX 模型",
                    status_code=503,
                )
        if not image_bytes:
            raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)
        try:
            if normalized_question is not None:
                fun_pipeline = fun_captcha_pipelines.get(normalized_question)
                if fun_pipeline is None:
                    raise ServiceError(
                        "ERROR_NO_MODELS_LOADED",
                        f"专项模型未加载: {normalized_question}",
                        status_code=503,
                    )
                result = fun_pipeline.solve(image_bytes)
            else:
                result = pipeline.solve(image_bytes, captcha_type=normalized_type)
        except ServiceError:
            raise
        except (RuntimeError, TypeError) as exc:
            raise ServiceError(
                "ERROR_INFERENCE_FAILED",
                str(exc),
                status_code=400,
            ) from exc
        except Exception as exc:
            raise ServiceError(
                "ERROR_IMAGE_DECODE_FAILED",
                f"图片解析失败: {exc}",
                status_code=400,
            ) from exc
        payload = {
            "type": result["type"],
            "result": result["result"],
            "raw": result["raw"],
            "time_ms": result["time_ms"],
        }
        if "question" in result:
            payload["question"] = result["question"]
        if "objects" in result:
            payload["objects"] = result["objects"]
        if "scores" in result:
            payload["scores"] = result["scores"]
        return payload

    def _validate_callback_url(callback_url: str | None) -> str | None:
        """Accept only absolute http(s) callback URLs; empty means "no callback".

        The URL is later passed to ``urlopen``, which would also open
        ``file://``, ``ftp://`` or ``data:`` URLs supplied by a client.
        """
        if not callback_url:
            return callback_url
        parts = urlsplit(callback_url)
        if parts.scheme not in ("http", "https") or not parts.netloc:
            raise ServiceError(
                "ERROR_BAD_CALLBACK_URL",
                "callbackUrl 无效,仅支持 http/https 地址",
                status_code=400,
            )
        return callback_url

    task_manager = TaskManager(
        solve_fn=_infer,
        ttl_seconds=SERVER_CONFIG["task_ttl_seconds"],
        max_workers=SERVER_CONFIG["task_workers"],
        tasks_dir=SERVER_CONFIG["tasks_dir"],
    )

    @asynccontextmanager
    async def lifespan(_app: FastAPI):
        nonlocal pipeline
        if load_pipeline_on_startup:
            try:
                pipeline = pipeline_factory()
            except FileNotFoundError:
                pipeline = None
        if load_fun_captcha_on_startup:
            for question, factory in funcaptcha_factories.items():
                try:
                    fun_captcha_pipelines[question] = factory()
                except FileNotFoundError:
                    continue
        try:
            yield
        finally:
            task_manager.shutdown()

    app = FastAPI(
        title="CaptchaBreaker",
        description="验证码识别多模型系统 - HTTP 推理服务",
        version="0.2.0",
        lifespan=lifespan,
    )
    app.state.task_manager = task_manager
    if not load_pipeline_on_startup:
        pipeline = pipeline_factory()
    if not load_fun_captcha_on_startup:
        fun_captcha_pipelines = {
            question: factory()
            for question, factory in funcaptcha_factories.items()
        }

    # ---- Request bodies ----
    class SolveRequest(BaseModel):
        image: str
        type: Optional[str] = None
        question: Optional[str] = None

    class AsyncTaskRequest(BaseModel):
        type: str
        body: Optional[str] = None
        image: Optional[str] = None
        captchaType: Optional[str] = None
        question: Optional[str] = None

    class CreateTaskRequest(BaseModel):
        clientKey: Optional[str] = None
        callbackUrl: Optional[str] = None
        softId: Optional[int] = None
        languagePool: Optional[str] = None
        task: AsyncTaskRequest

    class GetTaskResultRequest(BaseModel):
        clientKey: Optional[str] = None
        taskId: str

    class BalanceRequest(BaseModel):
        clientKey: Optional[str] = None

    # ---- Routes ----
    @app.get("/health")
    @app.get("/api/v1/health")
    def health():
        models_loaded = pipeline is not None or bool(fun_captcha_pipelines)
        return {
            "status": "ok" if models_loaded else "no_models",
            "models_loaded": models_loaded,
            "async_tasks": task_manager.stats(),
            "client_key_required": bool(SERVER_CONFIG.get("client_key")),
            "supported_task_types": sorted(ASYNC_TASK_TYPES),
            "supported_task_questions": sorted(FUN_CAPTCHA_TASKS),
        }

    @app.post("/solve")
    @app.post("/api/v1/solve")
    async def solve_base64(req: SolveRequest):
        """JSON request; the image is sent base64-encoded."""
        try:
            image_bytes = _decode_image_b64(req.image)
            # Inference is blocking CPU work: run it in a worker thread so it
            # does not stall the event loop (health checks, task polling).
            return await asyncio.to_thread(
                _infer, image_bytes, req.type, getattr(req, "question", None)
            )
        except ServiceError as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"error": exc.description, "error_code": exc.code},
            )

    @app.post("/solve/upload")
    @app.post("/api/v1/solve/upload")
    async def solve_upload(
        image: UploadFile = File(...),
        type: Optional[str] = Query(None, description="指定类型跳过分类"),
        question: Optional[str] = Query(None, description="专项 question"),
    ):
        """Multipart file upload."""
        try:
            image_bytes = await image.read()
            # Same as /solve: keep blocking inference off the event loop.
            return await asyncio.to_thread(_infer, image_bytes, type, question)
        except ServiceError as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"error": exc.description, "error_code": exc.code},
            )

    @app.post("/createTask")
    @app.post("/api/v1/createTask")
    async def create_task(req: CreateTaskRequest, request: Request):
        """Create an async recognition task and return its taskId."""
        task = req.task
        try:
            _validate_client_key(req.clientKey)
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)
        if task.type not in ASYNC_TASK_TYPES:
            return _task_error_payload(
                "ERROR_TASK_TYPE_NOT_SUPPORTED",
                f"不支持的任务类型: {task.type}",
            )
        try:
            callback_url = _validate_callback_url(getattr(req, "callbackUrl", None))
            question = _normalize_question(getattr(task, "question", None))
            if question is None:
                _normalize_captcha_type(task.captchaType)
            image_bytes = _decode_image_b64(task.body or task.image or "")
            # Reject up front when the required model is missing, rather than
            # queueing a task that can only fail.
            if question is None and pipeline is None:
                raise ServiceError(
                    "ERROR_NO_MODELS_LOADED",
                    "模型未加载,请先训练并导出 ONNX 模型",
                    status_code=503,
                )
            if question is not None and question not in fun_captcha_pipelines:
                raise ServiceError(
                    "ERROR_NO_MODELS_LOADED",
                    f"专项模型未加载: {question}",
                    status_code=503,
                )
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)
        client = getattr(request, "client", None)
        client_ip = getattr(client, "host", None)
        task_id = task_manager.create_task(
            image_bytes,
            task.captchaType,
            question=question,
            client_ip=client_ip,
            task_type=task.type,
            callback_url=callback_url,
        )
        record = task_manager.get_task(task_id)
        return _task_success_payload(
            taskId=task_id,
            status="processing",
            createTime=record.created_at if record else int(time.time()),
            expiresAt=record.expires_at if record else None,
        )

    @app.post("/getTaskResult")
    @app.post("/api/v1/getTaskResult")
    async def get_task_result(req: GetTaskResultRequest):
        """Look up a task's status and result by taskId."""
        try:
            _validate_client_key(req.clientKey)
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description, taskId=req.taskId)
        task = task_manager.get_task(str(req.taskId))
        if task is None:
            return _task_error_payload(
                "ERROR_TASK_NOT_FOUND",
                f"任务不存在或已过期: {req.taskId}",
                taskId=req.taskId,
            )
        return _build_task_result_payload(task)

    @app.post("/getBalance")
    @app.post("/api/v1/getBalance")
    async def get_balance(req: BalanceRequest | None = Body(default=None)):
        """Return the configured placeholder balance (no real accounting)."""
        try:
            _validate_client_key(getattr(req, "clientKey", None))
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)
        return _task_success_payload(balance=SERVER_CONFIG["balance"])

    return app