Align task API and add FunCaptcha support
This commit is contained in:
816
server.py
816
server.py
@@ -4,117 +4,819 @@ FastAPI HTTP 推理服务 (纯推理,不依赖 torch/训练代码)
|
||||
仅依赖: fastapi, uvicorn, python-multipart, onnxruntime, pillow, numpy
|
||||
|
||||
API:
|
||||
POST /solve JSON base64 图片识别
|
||||
POST /solve/upload multipart 文件上传识别
|
||||
GET /health 健康检查
|
||||
POST /solve JSON base64 图片识别 (同步)
|
||||
POST /solve/upload multipart 文件上传识别 (同步)
|
||||
POST /createTask 创建异步识别任务
|
||||
POST /getTaskResult 查询异步任务结果
|
||||
POST /getBalance 查询本地服务余额占位值
|
||||
POST /api/v1/* 兼容别名
|
||||
GET /health 健康检查
|
||||
GET /api/v1/health 健康检查兼容别名
|
||||
|
||||
启动:
|
||||
uv sync --extra server
|
||||
python cli.py serve --port 8080
|
||||
|
||||
请求示例 (base64):
|
||||
curl -X POST http://localhost:8080/solve \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"image": "<base64>", "type": "normal"}'
|
||||
|
||||
请求示例 (文件上传):
|
||||
curl -X POST http://localhost:8080/solve/upload -F "image=@captcha.png"
|
||||
|
||||
响应:
|
||||
{"type": "normal", "result": "A3B8", "raw": "A3B8", "time_ms": 12.3}
|
||||
uv run captcha serve --port 8080
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from urllib.error import URLError
|
||||
from urllib.parse import urlencode
|
||||
from urllib.request import Request as UrlRequest, urlopen
|
||||
|
||||
from config import CAPTCHA_TYPES, FUN_CAPTCHA_TASKS, SERVER_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app():
|
||||
ASYNC_TASK_TYPES = {
|
||||
"ImageToTextTask",
|
||||
"ImageToTextTaskM1",
|
||||
"ImageToTextTaskMuggle",
|
||||
"ImageToTextTaskProxyless",
|
||||
"CaptchaImageTask",
|
||||
"FunCaptcha",
|
||||
"FunCaptchaTask",
|
||||
"FunCaptchaTaskProxyless",
|
||||
}
|
||||
|
||||
|
||||
class ServiceError(Exception):
|
||||
"""服务内部可预期错误。"""
|
||||
|
||||
def __init__(self, code: str, description: str, status_code: int = 400):
|
||||
super().__init__(description)
|
||||
self.code = code
|
||||
self.description = description
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
task_id: str
|
||||
created_at: int
|
||||
expires_at: int
|
||||
client_ip: str | None = None
|
||||
task_type: str | None = None
|
||||
captcha_type: str | None = None
|
||||
question: str | None = None
|
||||
callback_url: str | None = None
|
||||
callback_attempts: int = 0
|
||||
callback_delivered_at: int | None = None
|
||||
callback_last_error: str | None = None
|
||||
status: str = "processing"
|
||||
result: dict | None = None
|
||||
error_code: str | None = None
|
||||
error_description: str | None = None
|
||||
completed_at: int | None = None
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""简单的内存任务管理器,适合单进程部署。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
solve_fn: Callable[[bytes, str | None, str | None], dict],
|
||||
ttl_seconds: int,
|
||||
max_workers: int,
|
||||
tasks_dir: str | Path,
|
||||
):
|
||||
self._solve_fn = solve_fn
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._tasks_dir = Path(tasks_dir)
|
||||
self._tasks_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
thread_name_prefix="captcha-task",
|
||||
)
|
||||
self._tasks: dict[str, TaskRecord] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._load_tasks()
|
||||
|
||||
def create_task(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
captcha_type: str | None,
|
||||
question: str | None = None,
|
||||
client_ip: str | None = None,
|
||||
task_type: str | None = None,
|
||||
callback_url: str | None = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
task_id = uuid.uuid4().hex
|
||||
record = TaskRecord(
|
||||
task_id=task_id,
|
||||
created_at=now,
|
||||
expires_at=now + self._ttl_seconds,
|
||||
client_ip=client_ip,
|
||||
task_type=task_type,
|
||||
captcha_type=captcha_type,
|
||||
question=question,
|
||||
callback_url=callback_url,
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_locked(now)
|
||||
self._tasks[task_id] = record
|
||||
self._persist_task_locked(record)
|
||||
|
||||
self._executor.submit(self._run_task, task_id, image_bytes, captcha_type, question)
|
||||
return task_id
|
||||
|
||||
def get_task(self, task_id: str) -> TaskRecord | None:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def stats(self) -> dict:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
processing = sum(1 for task in self._tasks.values() if task.status == "processing")
|
||||
ready = sum(1 for task in self._tasks.values() if task.status == "ready")
|
||||
failed = sum(1 for task in self._tasks.values() if task.status == "failed")
|
||||
return {
|
||||
"active": len(self._tasks),
|
||||
"processing": processing,
|
||||
"ready": ready,
|
||||
"failed": failed,
|
||||
"ttl_seconds": self._ttl_seconds,
|
||||
}
|
||||
|
||||
def shutdown(self):
|
||||
self._executor.shutdown(wait=True, cancel_futures=False)
|
||||
|
||||
def _run_task(
|
||||
self,
|
||||
task_id: str,
|
||||
image_bytes: bytes,
|
||||
captcha_type: str | None,
|
||||
question: str | None,
|
||||
):
|
||||
try:
|
||||
result = self._solve_fn(image_bytes, captcha_type, question)
|
||||
except ServiceError as exc:
|
||||
task = self._mark_failed(task_id, exc.code, exc.description)
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
task = self._mark_failed(task_id, "ERROR_TASK_FAILED", str(exc))
|
||||
else:
|
||||
task = self._mark_ready(task_id, result)
|
||||
|
||||
if task and task.callback_url:
|
||||
self._send_task_callback(task)
|
||||
|
||||
def _mark_ready(self, task_id: str, result: dict) -> TaskRecord | None:
|
||||
with self._lock:
|
||||
record = self._tasks.get(task_id)
|
||||
if record is None:
|
||||
return None
|
||||
record.status = "ready"
|
||||
record.result = result
|
||||
record.completed_at = int(time.time())
|
||||
self._persist_task_locked(record)
|
||||
return record
|
||||
|
||||
def _mark_failed(self, task_id: str, code: str, description: str) -> TaskRecord | None:
|
||||
with self._lock:
|
||||
record = self._tasks.get(task_id)
|
||||
if record is None:
|
||||
return None
|
||||
record.status = "failed"
|
||||
record.error_code = code
|
||||
record.error_description = description
|
||||
record.completed_at = int(time.time())
|
||||
self._persist_task_locked(record)
|
||||
return record
|
||||
|
||||
def _cleanup_locked(self, now: int | None = None):
|
||||
now = now or int(time.time())
|
||||
expired_ids = [
|
||||
task_id
|
||||
for task_id, task in self._tasks.items()
|
||||
if task.expires_at <= now
|
||||
]
|
||||
for task_id in expired_ids:
|
||||
self._tasks.pop(task_id, None)
|
||||
self._delete_task_file(task_id)
|
||||
|
||||
def _update_callback_state(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
attempts: int | None = None,
|
||||
delivered_at: int | None = None,
|
||||
last_error: str | None = None,
|
||||
):
|
||||
with self._lock:
|
||||
record = self._tasks.get(task_id)
|
||||
if record is None:
|
||||
return
|
||||
if attempts is not None:
|
||||
record.callback_attempts = attempts
|
||||
if delivered_at is not None:
|
||||
record.callback_delivered_at = delivered_at
|
||||
record.callback_last_error = last_error
|
||||
self._persist_task_locked(record)
|
||||
|
||||
def _task_path(self, task_id: str) -> Path:
|
||||
return self._tasks_dir / f"{task_id}.json"
|
||||
|
||||
def _persist_task_locked(self, record: TaskRecord):
|
||||
path = self._task_path(record.task_id)
|
||||
tmp_path = path.with_suffix(".json.tmp")
|
||||
tmp_path.write_text(
|
||||
json.dumps(asdict(record), ensure_ascii=False, sort_keys=True),
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _delete_task_file(self, task_id: str):
|
||||
self._task_path(task_id).unlink(missing_ok=True)
|
||||
|
||||
def _load_tasks(self):
|
||||
now = int(time.time())
|
||||
for path in sorted(self._tasks_dir.glob("*.json")):
|
||||
try:
|
||||
task_data = json.loads(path.read_text(encoding="utf-8"))
|
||||
task = TaskRecord(**task_data)
|
||||
except Exception as exc: # pragma: no cover - 防御性兜底
|
||||
logger.warning("skip invalid task file: path=%s err=%s", path, exc)
|
||||
path.unlink(missing_ok=True)
|
||||
continue
|
||||
|
||||
if task.expires_at <= now:
|
||||
path.unlink(missing_ok=True)
|
||||
continue
|
||||
|
||||
self._tasks[task.task_id] = task
|
||||
|
||||
def _send_task_callback(self, task: TaskRecord):
|
||||
max_retries = max(0, int(SERVER_CONFIG.get("callback_max_retries", 0)))
|
||||
delay = max(0.0, float(SERVER_CONFIG.get("callback_retry_delay_seconds", 0.0)))
|
||||
backoff = max(1.0, float(SERVER_CONFIG.get("callback_retry_backoff", 1.0)))
|
||||
payload = _build_task_callback_payload(task)
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
attempt_no = attempt + 1
|
||||
try:
|
||||
_post_callback(task.callback_url, payload)
|
||||
self._update_callback_state(
|
||||
task.task_id,
|
||||
attempts=attempt_no,
|
||||
delivered_at=int(time.time()),
|
||||
last_error=None,
|
||||
)
|
||||
return
|
||||
except (URLError, OSError, ValueError) as exc:
|
||||
self._update_callback_state(
|
||||
task.task_id,
|
||||
attempts=attempt_no,
|
||||
last_error=str(exc),
|
||||
)
|
||||
if attempt >= max_retries:
|
||||
logger.warning(
|
||||
"task callback failed: task_id=%s url=%s attempts=%s err=%s",
|
||||
task.task_id,
|
||||
task.callback_url,
|
||||
attempt_no,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"task callback retry: task_id=%s url=%s attempt=%s err=%s",
|
||||
task.task_id,
|
||||
task.callback_url,
|
||||
attempt_no,
|
||||
exc,
|
||||
)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
delay *= backoff
|
||||
|
||||
|
||||
def _task_success_payload(**extra) -> dict:
|
||||
payload = {"errorId": 0}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
|
||||
|
||||
def _task_error_payload(code: str, description: str, **extra) -> dict:
|
||||
payload = {
|
||||
"errorId": 1,
|
||||
"errorCode": code,
|
||||
"errorDescription": description,
|
||||
}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_captcha_type(captcha_type: str | None) -> str | None:
|
||||
if captcha_type in (None, ""):
|
||||
return None
|
||||
if captcha_type not in CAPTCHA_TYPES:
|
||||
raise ServiceError(
|
||||
"ERROR_BAD_CAPTCHA_TYPE",
|
||||
f"不支持的验证码类型: {captcha_type}",
|
||||
status_code=400,
|
||||
)
|
||||
return captcha_type
|
||||
|
||||
|
||||
def _normalize_question(question: str | None) -> str | None:
|
||||
if question in (None, ""):
|
||||
return None
|
||||
if question not in FUN_CAPTCHA_TASKS:
|
||||
raise ServiceError(
|
||||
"ERROR_TASK_QUESTION_NOT_SUPPORTED",
|
||||
f"不支持的专项任务 question: {question}",
|
||||
status_code=400,
|
||||
)
|
||||
return question
|
||||
|
||||
|
||||
def _decode_image_b64(encoded: str) -> bytes:
|
||||
if not encoded:
|
||||
raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)
|
||||
|
||||
if encoded.startswith("data:") and "," in encoded:
|
||||
encoded = encoded.split(",", 1)[1]
|
||||
|
||||
encoded = "".join(encoded.split())
|
||||
encoded += "=" * (-len(encoded) % 4)
|
||||
|
||||
try:
|
||||
return base64.b64decode(encoded, altchars=b"-_", validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
raise ServiceError(
|
||||
"ERROR_BAD_IMAGE_BASE64",
|
||||
"base64 解码失败",
|
||||
status_code=400,
|
||||
) from exc
|
||||
|
||||
|
||||
def _validate_client_key(client_key: str | None):
|
||||
expected = SERVER_CONFIG.get("client_key")
|
||||
if not expected:
|
||||
return
|
||||
if client_key != expected:
|
||||
raise ServiceError(
|
||||
"ERROR_KEY_DOES_NOT_EXIST",
|
||||
"clientKey 无效",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
|
||||
def _build_task_solution(task: TaskRecord) -> dict:
|
||||
if task.result is None:
|
||||
return {}
|
||||
|
||||
if "objects" in task.result:
|
||||
objects = list(task.result["objects"])
|
||||
primary = objects[0] if objects else None
|
||||
solution = {
|
||||
"objects": objects,
|
||||
"answer": primary,
|
||||
"raw": task.result.get("raw", "" if primary is None else str(primary)),
|
||||
"timeMs": task.result["time_ms"],
|
||||
}
|
||||
if task.question:
|
||||
solution["question"] = task.question
|
||||
if primary is not None:
|
||||
solution["text"] = str(primary)
|
||||
return solution
|
||||
|
||||
return {
|
||||
"text": task.result["result"],
|
||||
"answer": task.result["result"],
|
||||
"raw": task.result["raw"],
|
||||
"captchaType": task.result["type"],
|
||||
"timeMs": task.result["time_ms"],
|
||||
}
|
||||
|
||||
|
||||
def _build_task_meta(task: TaskRecord) -> dict:
|
||||
payload = {
|
||||
"type": task.task_type,
|
||||
"captchaType": task.captcha_type,
|
||||
}
|
||||
if task.question is not None:
|
||||
payload["question"] = task.question
|
||||
return payload
|
||||
|
||||
|
||||
def _build_task_result_payload(task: TaskRecord) -> dict:
|
||||
if task.status == "processing":
|
||||
return _task_success_payload(
|
||||
taskId=task.task_id,
|
||||
status="processing",
|
||||
createTime=task.created_at,
|
||||
expiresAt=task.expires_at,
|
||||
task=_build_task_meta(task),
|
||||
callback={
|
||||
"configured": bool(task.callback_url),
|
||||
"url": task.callback_url,
|
||||
"attempts": task.callback_attempts,
|
||||
"delivered": False,
|
||||
"deliveredAt": task.callback_delivered_at,
|
||||
"lastError": task.callback_last_error,
|
||||
},
|
||||
)
|
||||
|
||||
if task.status == "failed":
|
||||
return _task_error_payload(
|
||||
task.error_code or "ERROR_TASK_FAILED",
|
||||
task.error_description or "任务执行失败",
|
||||
taskId=task.task_id,
|
||||
status="failed",
|
||||
ip=task.client_ip,
|
||||
createTime=task.created_at,
|
||||
endTime=task.completed_at,
|
||||
expiresAt=task.expires_at,
|
||||
task=_build_task_meta(task),
|
||||
callback={
|
||||
"configured": bool(task.callback_url),
|
||||
"url": task.callback_url,
|
||||
"attempts": task.callback_attempts,
|
||||
"delivered": task.callback_delivered_at is not None,
|
||||
"deliveredAt": task.callback_delivered_at,
|
||||
"lastError": task.callback_last_error,
|
||||
},
|
||||
)
|
||||
|
||||
return _task_success_payload(
|
||||
taskId=task.task_id,
|
||||
status="ready",
|
||||
solution=_build_task_solution(task),
|
||||
cost=f"{float(SERVER_CONFIG['task_cost']):.5f}",
|
||||
ip=task.client_ip,
|
||||
createTime=task.created_at,
|
||||
endTime=task.completed_at,
|
||||
expiresAt=task.expires_at,
|
||||
solveCount=1,
|
||||
task=_build_task_meta(task),
|
||||
callback={
|
||||
"configured": bool(task.callback_url),
|
||||
"url": task.callback_url,
|
||||
"attempts": task.callback_attempts,
|
||||
"delivered": task.callback_delivered_at is not None,
|
||||
"deliveredAt": task.callback_delivered_at,
|
||||
"lastError": task.callback_last_error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_task_callback_payload(task: TaskRecord) -> dict[str, str]:
|
||||
payload = {
|
||||
"id": task.task_id,
|
||||
"taskId": task.task_id,
|
||||
"status": task.status,
|
||||
"errorId": "0" if task.status == "ready" else "1",
|
||||
}
|
||||
if task.status == "ready":
|
||||
solution = _build_task_solution(task)
|
||||
payload["code"] = str(solution.get("answer", ""))
|
||||
payload["answer"] = str(solution.get("answer", ""))
|
||||
payload["raw"] = str(solution.get("raw", ""))
|
||||
payload["timeMs"] = str(solution["timeMs"])
|
||||
payload["cost"] = f"{float(SERVER_CONFIG['task_cost']):.5f}"
|
||||
if "text" in solution:
|
||||
payload["text"] = str(solution["text"])
|
||||
if "captchaType" in solution:
|
||||
payload["captchaType"] = str(solution["captchaType"])
|
||||
if "objects" in solution:
|
||||
payload["objects"] = json.dumps(solution["objects"], ensure_ascii=False)
|
||||
if "question" in solution:
|
||||
payload["question"] = str(solution["question"])
|
||||
else:
|
||||
payload.update(
|
||||
{
|
||||
"errorCode": task.error_code or "ERROR_TASK_FAILED",
|
||||
"errorDescription": task.error_description or "任务执行失败",
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def _sign_callback_payload(data: bytes, timestamp: str, secret: str) -> str:
|
||||
message = timestamp.encode("utf-8") + b"." + data
|
||||
return hmac.new(secret.encode("utf-8"), message, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def _build_callback_request(callback_url: str, payload: dict[str, str]) -> UrlRequest:
|
||||
data = urlencode(payload).encode("utf-8")
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
secret = SERVER_CONFIG.get("callback_signing_secret")
|
||||
if secret:
|
||||
timestamp = str(int(time.time()))
|
||||
headers["X-CaptchaBreaker-Timestamp"] = timestamp
|
||||
headers["X-CaptchaBreaker-Signature-Alg"] = "hmac-sha256"
|
||||
headers["X-CaptchaBreaker-Signature"] = _sign_callback_payload(data, timestamp, secret)
|
||||
|
||||
return UrlRequest(
|
||||
callback_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
|
||||
|
||||
def _post_callback(callback_url: str, payload: dict[str, str]):
|
||||
request = _build_callback_request(callback_url, payload)
|
||||
timeout = SERVER_CONFIG["callback_timeout_seconds"]
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
response.read()
|
||||
|
||||
|
||||
|
||||
def create_app(pipeline_factory=None, funcaptcha_factories=None):
|
||||
"""
|
||||
创建 FastAPI 应用实例(工厂函数)。
|
||||
|
||||
cli.py 的 cmd_serve 依赖此签名。
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, File, Query, UploadFile
|
||||
from fastapi import Body, FastAPI, File, Query, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from inference.fun_captcha import FunCaptchaRollballPipeline
|
||||
from inference.pipeline import CaptchaPipeline
|
||||
|
||||
app = FastAPI(
|
||||
title="CaptchaBreaker",
|
||||
description="验证码识别多模型系统 - HTTP 推理服务",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
pipeline: Optional[CaptchaPipeline] = None
|
||||
load_pipeline_on_startup = pipeline_factory is None
|
||||
pipeline_factory = pipeline_factory or CaptchaPipeline
|
||||
fun_captcha_pipelines: dict[str, object] = {}
|
||||
load_fun_captcha_on_startup = funcaptcha_factories is None
|
||||
if funcaptcha_factories is None:
|
||||
funcaptcha_factories = {
|
||||
"4_3d_rollball_animals": lambda: FunCaptchaRollballPipeline(
|
||||
question="4_3d_rollball_animals"
|
||||
),
|
||||
}
|
||||
|
||||
# ---- 启动时加载模型 ----
|
||||
@app.on_event("startup")
|
||||
def _load_models():
|
||||
nonlocal pipeline
|
||||
try:
|
||||
pipeline = CaptchaPipeline()
|
||||
except FileNotFoundError:
|
||||
pipeline = None
|
||||
def _infer(image_bytes: bytes, captcha_type: str | None, question: str | None) -> dict:
|
||||
normalized_question = _normalize_question(question)
|
||||
normalized_type = None if normalized_question is not None else _normalize_captcha_type(captcha_type)
|
||||
|
||||
# ---- 请求体定义 ----
|
||||
class SolveRequest(BaseModel):
|
||||
image: str # base64 编码的图片
|
||||
type: Optional[str] = None # 指定类型可跳过分类
|
||||
|
||||
# ---- 通用推理逻辑 ----
|
||||
def _solve(image_bytes: bytes, captcha_type: Optional[str]) -> dict:
|
||||
if pipeline is None:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"error": "模型未加载,请先训练并导出 ONNX 模型"},
|
||||
)
|
||||
if normalized_question is None:
|
||||
raise ServiceError(
|
||||
"ERROR_NO_MODELS_LOADED",
|
||||
"模型未加载,请先训练并导出 ONNX 模型",
|
||||
status_code=503,
|
||||
)
|
||||
if not image_bytes:
|
||||
return JSONResponse(status_code=400, content={"error": "空图片"})
|
||||
raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)
|
||||
|
||||
try:
|
||||
result = pipeline.solve(image_bytes, captcha_type=captcha_type)
|
||||
except (RuntimeError, TypeError) as e:
|
||||
return JSONResponse(status_code=400, content={"error": str(e)})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"error": f"图片解析失败: {e}"})
|
||||
if normalized_question is not None:
|
||||
fun_pipeline = fun_captcha_pipelines.get(normalized_question)
|
||||
if fun_pipeline is None:
|
||||
raise ServiceError(
|
||||
"ERROR_NO_MODELS_LOADED",
|
||||
f"专项模型未加载: {normalized_question}",
|
||||
status_code=503,
|
||||
)
|
||||
result = fun_pipeline.solve(image_bytes)
|
||||
else:
|
||||
result = pipeline.solve(image_bytes, captcha_type=normalized_type)
|
||||
except ServiceError:
|
||||
raise
|
||||
except (RuntimeError, TypeError) as exc:
|
||||
raise ServiceError(
|
||||
"ERROR_INFERENCE_FAILED",
|
||||
str(exc),
|
||||
status_code=400,
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
raise ServiceError(
|
||||
"ERROR_IMAGE_DECODE_FAILED",
|
||||
f"图片解析失败: {exc}",
|
||||
status_code=400,
|
||||
) from exc
|
||||
|
||||
return {
|
||||
payload = {
|
||||
"type": result["type"],
|
||||
"result": result["result"],
|
||||
"raw": result["raw"],
|
||||
"time_ms": result["time_ms"],
|
||||
}
|
||||
if "question" in result:
|
||||
payload["question"] = result["question"]
|
||||
if "objects" in result:
|
||||
payload["objects"] = result["objects"]
|
||||
if "scores" in result:
|
||||
payload["scores"] = result["scores"]
|
||||
return payload
|
||||
|
||||
task_manager = TaskManager(
|
||||
solve_fn=_infer,
|
||||
ttl_seconds=SERVER_CONFIG["task_ttl_seconds"],
|
||||
max_workers=SERVER_CONFIG["task_workers"],
|
||||
tasks_dir=SERVER_CONFIG["tasks_dir"],
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
nonlocal pipeline
|
||||
if load_pipeline_on_startup:
|
||||
try:
|
||||
pipeline = pipeline_factory()
|
||||
except FileNotFoundError:
|
||||
pipeline = None
|
||||
if load_fun_captcha_on_startup:
|
||||
for question, factory in funcaptcha_factories.items():
|
||||
try:
|
||||
fun_captcha_pipelines[question] = factory()
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task_manager.shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
title="CaptchaBreaker",
|
||||
description="验证码识别多模型系统 - HTTP 推理服务",
|
||||
version="0.2.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.state.task_manager = task_manager
|
||||
|
||||
if not load_pipeline_on_startup:
|
||||
pipeline = pipeline_factory()
|
||||
if not load_fun_captcha_on_startup:
|
||||
fun_captcha_pipelines = {
|
||||
question: factory()
|
||||
for question, factory in funcaptcha_factories.items()
|
||||
}
|
||||
|
||||
# ---- 请求体定义 ----
|
||||
class SolveRequest(BaseModel):
|
||||
image: str
|
||||
type: Optional[str] = None
|
||||
question: Optional[str] = None
|
||||
|
||||
class AsyncTaskRequest(BaseModel):
|
||||
type: str
|
||||
body: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
captchaType: Optional[str] = None
|
||||
question: Optional[str] = None
|
||||
|
||||
class CreateTaskRequest(BaseModel):
|
||||
clientKey: Optional[str] = None
|
||||
callbackUrl: Optional[str] = None
|
||||
softId: Optional[int] = None
|
||||
languagePool: Optional[str] = None
|
||||
task: AsyncTaskRequest
|
||||
|
||||
class GetTaskResultRequest(BaseModel):
|
||||
clientKey: Optional[str] = None
|
||||
taskId: str
|
||||
|
||||
class BalanceRequest(BaseModel):
|
||||
clientKey: Optional[str] = None
|
||||
|
||||
# ---- 路由 ----
|
||||
@app.get("/health")
|
||||
@app.get("/api/v1/health")
|
||||
def health():
|
||||
models_loaded = pipeline is not None
|
||||
models_loaded = pipeline is not None or bool(fun_captcha_pipelines)
|
||||
return {
|
||||
"status": "ok" if models_loaded else "no_models",
|
||||
"models_loaded": models_loaded,
|
||||
"async_tasks": task_manager.stats(),
|
||||
"client_key_required": bool(SERVER_CONFIG.get("client_key")),
|
||||
"supported_task_types": sorted(ASYNC_TASK_TYPES),
|
||||
"supported_task_questions": sorted(FUN_CAPTCHA_TASKS),
|
||||
}
|
||||
|
||||
@app.post("/solve")
|
||||
@app.post("/api/v1/solve")
|
||||
async def solve_base64(req: SolveRequest):
|
||||
"""JSON 请求,图片通过 base64 传输。"""
|
||||
try:
|
||||
image_bytes = base64.b64decode(req.image)
|
||||
except Exception:
|
||||
image_bytes = _decode_image_b64(req.image)
|
||||
return _infer(image_bytes, req.type, getattr(req, "question", None))
|
||||
except ServiceError as exc:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "base64 解码失败"},
|
||||
status_code=exc.status_code,
|
||||
content={"error": exc.description, "error_code": exc.code},
|
||||
)
|
||||
return _solve(image_bytes, req.type)
|
||||
|
||||
@app.post("/solve/upload")
|
||||
@app.post("/api/v1/solve/upload")
|
||||
async def solve_upload(
|
||||
image: UploadFile = File(...),
|
||||
type: Optional[str] = Query(None, description="指定类型跳过分类"),
|
||||
question: Optional[str] = Query(None, description="专项 question"),
|
||||
):
|
||||
"""multipart 文件上传。"""
|
||||
data = await image.read()
|
||||
return _solve(data, type)
|
||||
try:
|
||||
return _infer(await image.read(), type, question)
|
||||
except ServiceError as exc:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": exc.description, "error_code": exc.code},
|
||||
)
|
||||
|
||||
@app.post("/createTask")
|
||||
@app.post("/api/v1/createTask")
|
||||
async def create_task(req: CreateTaskRequest, request: Request):
|
||||
"""创建异步识别任务,返回 taskId。"""
|
||||
task = req.task
|
||||
try:
|
||||
_validate_client_key(req.clientKey)
|
||||
except ServiceError as exc:
|
||||
return _task_error_payload(exc.code, exc.description)
|
||||
|
||||
if task.type not in ASYNC_TASK_TYPES:
|
||||
return _task_error_payload(
|
||||
"ERROR_TASK_TYPE_NOT_SUPPORTED",
|
||||
f"不支持的任务类型: {task.type}",
|
||||
)
|
||||
|
||||
try:
|
||||
question = _normalize_question(getattr(task, "question", None))
|
||||
if question is None:
|
||||
_normalize_captcha_type(task.captchaType)
|
||||
image_bytes = _decode_image_b64(task.body or task.image or "")
|
||||
if question is None and pipeline is None:
|
||||
raise ServiceError(
|
||||
"ERROR_NO_MODELS_LOADED",
|
||||
"模型未加载,请先训练并导出 ONNX 模型",
|
||||
status_code=503,
|
||||
)
|
||||
if question is not None and question not in fun_captcha_pipelines:
|
||||
raise ServiceError(
|
||||
"ERROR_NO_MODELS_LOADED",
|
||||
f"专项模型未加载: {question}",
|
||||
status_code=503,
|
||||
)
|
||||
except ServiceError as exc:
|
||||
return _task_error_payload(exc.code, exc.description)
|
||||
|
||||
client = getattr(request, "client", None)
|
||||
client_ip = getattr(client, "host", None)
|
||||
task_id = task_manager.create_task(
|
||||
image_bytes,
|
||||
task.captchaType,
|
||||
question=question,
|
||||
client_ip=client_ip,
|
||||
task_type=task.type,
|
||||
callback_url=getattr(req, "callbackUrl", None),
|
||||
)
|
||||
task = task_manager.get_task(task_id)
|
||||
return _task_success_payload(
|
||||
taskId=task_id,
|
||||
status="processing",
|
||||
createTime=task.created_at if task else int(time.time()),
|
||||
expiresAt=task.expires_at if task else None,
|
||||
)
|
||||
|
||||
@app.post("/getTaskResult")
|
||||
@app.post("/api/v1/getTaskResult")
|
||||
async def get_task_result(req: GetTaskResultRequest):
|
||||
"""按 taskId 查询任务状态与结果。"""
|
||||
try:
|
||||
_validate_client_key(req.clientKey)
|
||||
except ServiceError as exc:
|
||||
return _task_error_payload(exc.code, exc.description, taskId=req.taskId)
|
||||
|
||||
task = task_manager.get_task(str(req.taskId))
|
||||
if task is None:
|
||||
return _task_error_payload(
|
||||
"ERROR_TASK_NOT_FOUND",
|
||||
f"任务不存在或已过期: {req.taskId}",
|
||||
taskId=req.taskId,
|
||||
)
|
||||
|
||||
return _build_task_result_payload(task)
|
||||
|
||||
@app.post("/getBalance")
|
||||
@app.post("/api/v1/getBalance")
|
||||
async def get_balance(req: BalanceRequest | None = Body(default=None)):
|
||||
try:
|
||||
_validate_client_key(getattr(req, "clientKey", None))
|
||||
except ServiceError as exc:
|
||||
return _task_error_payload(exc.code, exc.description)
|
||||
return _task_success_payload(balance=SERVER_CONFIG["balance"])
|
||||
|
||||
return app
|
||||
|
||||
Reference in New Issue
Block a user