823 lines
27 KiB
Python
823 lines
27 KiB
Python
"""
|
|
FastAPI HTTP 推理服务 (纯推理,不依赖 torch/训练代码)
|
|
|
|
仅依赖: fastapi, uvicorn, python-multipart, onnxruntime, pillow, numpy
|
|
|
|
API:
|
|
POST /solve JSON base64 图片识别 (同步)
|
|
POST /solve/upload multipart 文件上传识别 (同步)
|
|
POST /createTask 创建异步识别任务
|
|
POST /getTaskResult 查询异步任务结果
|
|
POST /getBalance 查询本地服务余额占位值
|
|
POST /api/v1/* 兼容别名
|
|
GET /health 健康检查
|
|
GET /api/v1/health 健康检查兼容别名
|
|
|
|
启动:
|
|
uv sync --extra server
|
|
uv run captcha serve --port 8080
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import binascii
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
from urllib.error import URLError
|
|
from urllib.parse import urlencode
|
|
from urllib.request import Request as UrlRequest, urlopen
|
|
|
|
from config import CAPTCHA_TYPES, FUN_CAPTCHA_TASKS, SERVER_CONFIG
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Task ``type`` values accepted by the createTask endpoint; any other value is
# rejected there with ERROR_TASK_TYPE_NOT_SUPPORTED. Also reported by /health
# as ``supported_task_types``.
ASYNC_TASK_TYPES = {
    # Image-to-text style task names.
    "ImageToTextTask",
    "ImageToTextTaskM1",
    "ImageToTextTaskMuggle",
    "ImageToTextTaskProxyless",
    "CaptchaImageTask",
    # FunCaptcha style task names.
    "FunCaptcha",
    "FunCaptchaTask",
    "FunCaptchaTaskProxyless",
}
|
|
|
|
|
|
class ServiceError(Exception):
    """Expected, reportable failure raised inside the service.

    Attributes:
        code: Machine-readable error identifier (e.g. ``ERROR_EMPTY_IMAGE``).
        description: Human-readable message; also used as the exception text.
        status_code: HTTP status used by the synchronous endpoints when they
            turn this error into a response.
    """

    def __init__(self, code: str, description: str, status_code: int = 400):
        # The description doubles as the exception message, so str(exc) == description.
        super().__init__(description)
        self.code, self.description, self.status_code = code, description, status_code
|
|
|
|
|
|
@dataclass
class TaskRecord:
    """State of one asynchronous task, as kept in memory and persisted to JSON."""

    # --- identity and lifetime (Unix timestamps in whole seconds) ---
    task_id: str
    created_at: int
    expires_at: int

    # --- what was requested ---
    client_ip: str | None = None
    task_type: str | None = None
    captcha_type: str | None = None
    question: str | None = None

    # --- callback delivery bookkeeping ---
    callback_url: str | None = None
    callback_attempts: int = 0
    callback_delivered_at: int | None = None
    callback_last_error: str | None = None

    # --- outcome: status is "processing", "ready" or "failed" ---
    status: str = "processing"
    result: dict | None = None
    error_code: str | None = None
    error_description: str | None = None
    completed_at: int | None = None
|
|
|
|
|
|
class TaskManager:
    """Simple in-memory task manager, suitable for single-process deployments.

    Tasks are solved on a thread pool. Every state change is mirrored to a
    ``<task_id>.json`` file under ``tasks_dir`` so finished results survive a
    restart. Tasks are dropped (memory and file) once ``expires_at`` passes;
    expiry is checked lazily whenever the manager is accessed.
    """

    def __init__(
        self,
        solve_fn: Callable[[bytes, str | None, str | None], dict],
        ttl_seconds: int,
        max_workers: int,
        tasks_dir: str | Path,
    ):
        """Create the manager and reload any unexpired task files.

        Args:
            solve_fn: Called as ``solve_fn(image_bytes, captcha_type, question)``
                on a worker thread; its return value becomes the task result.
            ttl_seconds: Lifetime of a task, counted from its creation.
            max_workers: Size of the worker thread pool.
            tasks_dir: Directory for the per-task JSON files (created if missing).
        """
        self._solve_fn = solve_fn
        self._ttl_seconds = ttl_seconds
        self._tasks_dir = Path(tasks_dir)
        self._tasks_dir.mkdir(parents=True, exist_ok=True)
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="captcha-task",
        )
        self._tasks: dict[str, TaskRecord] = {}
        self._lock = threading.Lock()
        self._load_tasks()

    def create_task(
        self,
        image_bytes: bytes,
        captcha_type: str | None,
        question: str | None = None,
        client_ip: str | None = None,
        task_type: str | None = None,
        callback_url: str | None = None,
    ) -> str:
        """Register a new task, persist it and queue it on the thread pool.

        Returns:
            The generated task id (32 hex characters).
        """
        now = int(time.time())
        task_id = uuid.uuid4().hex
        record = TaskRecord(
            task_id=task_id,
            created_at=now,
            expires_at=now + self._ttl_seconds,
            client_ip=client_ip,
            task_type=task_type,
            captcha_type=captcha_type,
            question=question,
            callback_url=callback_url,
        )
        with self._lock:
            self._cleanup_locked(now)
            self._tasks[task_id] = record
            self._persist_task_locked(record)

        # Submitted outside the lock; the worker re-acquires it to publish the outcome.
        self._executor.submit(self._run_task, task_id, image_bytes, captcha_type, question)
        return task_id

    def get_task(self, task_id: str) -> TaskRecord | None:
        """Return the live record for ``task_id``, or None if unknown or expired.

        The returned object is the one worker threads update, so callers read
        it without the lock; see ``_mark_ready`` for the ordering that keeps
        such reads consistent.
        """
        with self._lock:
            self._cleanup_locked()
            return self._tasks.get(task_id)

    def stats(self) -> dict:
        """Return task counts by status plus the configured TTL."""
        with self._lock:
            self._cleanup_locked()
            statuses = [task.status for task in self._tasks.values()]
            return {
                "active": len(statuses),
                "processing": statuses.count("processing"),
                "ready": statuses.count("ready"),
                "failed": statuses.count("failed"),
                "ttl_seconds": self._ttl_seconds,
            }

    def shutdown(self):
        """Stop accepting work and wait for queued and running tasks to finish."""
        self._executor.shutdown(wait=True, cancel_futures=False)

    def _run_task(
        self,
        task_id: str,
        image_bytes: bytes,
        captcha_type: str | None,
        question: str | None,
    ):
        """Worker-thread entry point: solve, record the outcome, then notify."""
        try:
            result = self._solve_fn(image_bytes, captcha_type, question)
        except ServiceError as exc:
            task = self._mark_failed(task_id, exc.code, exc.description)
        except Exception as exc:  # pragma: no cover - defensive catch-all
            task = self._mark_failed(task_id, "ERROR_TASK_FAILED", str(exc))
        else:
            task = self._mark_ready(task_id, result)

        if task is not None and task.callback_url:
            try:
                self._send_task_callback(task)
            except Exception:  # pragma: no cover - defensive catch-all
                # An exception escaping here would only be stored on the unused
                # Future returned by submit(), i.e. lost; log it instead.
                logger.exception("task callback crashed: task_id=%s", task_id)

    def _mark_ready(self, task_id: str, result: dict) -> TaskRecord | None:
        """Store a successful result; returns None if the task already expired."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return None
            record.result = result
            record.completed_at = int(time.time())
            # Publish the status last: get_task() hands out this same object and
            # callers read it without the lock, so "ready" must never be visible
            # before the result it refers to.
            record.status = "ready"
            self._persist_task_locked(record)
            return record

    def _mark_failed(self, task_id: str, code: str, description: str) -> TaskRecord | None:
        """Store a failure; returns None if the task already expired."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return None
            record.error_code = code
            record.error_description = description
            record.completed_at = int(time.time())
            # Status last, for the same reason as in _mark_ready.
            record.status = "failed"
            self._persist_task_locked(record)
            return record

    def _cleanup_locked(self, now: int | None = None):
        """Drop every task whose ``expires_at`` is <= ``now``. Caller holds the lock.

        ``now`` defaults to the current time only when it is omitted; an
        explicit value (including 0) is used as given.
        """
        if now is None:
            now = int(time.time())
        expired_ids = [
            task_id
            for task_id, task in self._tasks.items()
            if task.expires_at <= now
        ]
        for task_id in expired_ids:
            self._tasks.pop(task_id, None)
            self._delete_task_file(task_id)

    def _update_callback_state(
        self,
        task_id: str,
        *,
        attempts: int | None = None,
        delivered_at: int | None = None,
        last_error: str | None = None,
    ):
        """Record callback progress on the task and persist it.

        ``attempts`` and ``delivered_at`` are only written when given;
        ``last_error`` is always written, so passing None clears it.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return
            if attempts is not None:
                record.callback_attempts = attempts
            if delivered_at is not None:
                record.callback_delivered_at = delivered_at
            record.callback_last_error = last_error
            self._persist_task_locked(record)

    def _task_path(self, task_id: str) -> Path:
        """Return the JSON file path used for ``task_id``."""
        return self._tasks_dir / f"{task_id}.json"

    def _persist_task_locked(self, record: TaskRecord):
        """Write the record to its JSON file. Caller holds the lock."""
        path = self._task_path(record.task_id)
        # Write to a sibling temp file, then rename over the target so a reader
        # never sees a half-written file.
        tmp_path = path.with_suffix(".json.tmp")
        tmp_path.write_text(
            json.dumps(asdict(record), ensure_ascii=False, sort_keys=True),
            encoding="utf-8",
        )
        tmp_path.replace(path)

    def _delete_task_file(self, task_id: str):
        """Remove the task's JSON file if it exists."""
        self._task_path(task_id).unlink(missing_ok=True)

    def _load_tasks(self):
        """Load unexpired task files from disk; delete expired or unreadable ones.

        NOTE(review): records restored in the "processing" state are kept as
        they are. Nothing here resubmits them and the image bytes are not part
        of the persisted record, so they presumably stay "processing" until
        they expire - confirm whether they should be marked failed instead.
        """
        now = int(time.time())
        for path in sorted(self._tasks_dir.glob("*.json")):
            try:
                task_data = json.loads(path.read_text(encoding="utf-8"))
                task = TaskRecord(**task_data)
            except Exception as exc:  # pragma: no cover - defensive catch-all
                logger.warning("skip invalid task file: path=%s err=%s", path, exc)
                path.unlink(missing_ok=True)
                continue

            if task.expires_at <= now:
                path.unlink(missing_ok=True)
                continue

            self._tasks[task.task_id] = task

    def _send_task_callback(self, task: TaskRecord):
        """POST the task outcome to ``task.callback_url``, retrying on failure.

        Retry count, initial delay and backoff multiplier come from
        SERVER_CONFIG (``callback_max_retries``, ``callback_retry_delay_seconds``,
        ``callback_retry_backoff``). Every attempt is recorded on the task.
        """
        # Local import: the module header does not import http.client.
        from http.client import HTTPException

        max_retries = max(0, int(SERVER_CONFIG.get("callback_max_retries", 0)))
        delay = max(0.0, float(SERVER_CONFIG.get("callback_retry_delay_seconds", 0.0)))
        backoff = max(1.0, float(SERVER_CONFIG.get("callback_retry_backoff", 1.0)))
        payload = _build_task_callback_payload(task)

        for attempt in range(max_retries + 1):
            attempt_no = attempt + 1
            try:
                _post_callback(task.callback_url, payload)
                self._update_callback_state(
                    task.task_id,
                    attempts=attempt_no,
                    delivered_at=int(time.time()),
                    last_error=None,
                )
                return
            # HTTPException covers malformed or truncated responses (bad status
            # line, incomplete read), which urllib does not wrap in URLError.
            except (URLError, OSError, ValueError, HTTPException) as exc:
                self._update_callback_state(
                    task.task_id,
                    attempts=attempt_no,
                    last_error=str(exc),
                )
                if attempt >= max_retries:
                    logger.warning(
                        "task callback failed: task_id=%s url=%s attempts=%s err=%s",
                        task.task_id,
                        task.callback_url,
                        attempt_no,
                        exc,
                    )
                    return

                logger.warning(
                    "task callback retry: task_id=%s url=%s attempt=%s err=%s",
                    task.task_id,
                    task.callback_url,
                    attempt_no,
                    exc,
                )
                if delay > 0:
                    time.sleep(delay)
                    delay *= backoff
|
|
|
|
|
|
def _task_success_payload(**extra) -> dict:
|
|
payload = {"errorId": 0}
|
|
payload.update(extra)
|
|
return payload
|
|
|
|
|
|
def _task_error_payload(code: str, description: str, **extra) -> dict:
|
|
payload = {
|
|
"errorId": 1,
|
|
"errorCode": code,
|
|
"errorDescription": description,
|
|
}
|
|
payload.update(extra)
|
|
return payload
|
|
|
|
|
|
def _normalize_captcha_type(captcha_type: str | None) -> str | None:
|
|
if captcha_type in (None, ""):
|
|
return None
|
|
if captcha_type not in CAPTCHA_TYPES:
|
|
raise ServiceError(
|
|
"ERROR_BAD_CAPTCHA_TYPE",
|
|
f"不支持的验证码类型: {captcha_type}",
|
|
status_code=400,
|
|
)
|
|
return captcha_type
|
|
|
|
|
|
def _normalize_question(question: str | None) -> str | None:
|
|
if question in (None, ""):
|
|
return None
|
|
if question not in FUN_CAPTCHA_TASKS:
|
|
raise ServiceError(
|
|
"ERROR_TASK_QUESTION_NOT_SUPPORTED",
|
|
f"不支持的专项任务 question: {question}",
|
|
status_code=400,
|
|
)
|
|
return question
|
|
|
|
|
|
def _decode_image_b64(encoded: str) -> bytes:
|
|
if not encoded:
|
|
raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)
|
|
|
|
if encoded.startswith("data:") and "," in encoded:
|
|
encoded = encoded.split(",", 1)[1]
|
|
|
|
encoded = "".join(encoded.split())
|
|
encoded += "=" * (-len(encoded) % 4)
|
|
|
|
try:
|
|
return base64.b64decode(encoded, altchars=b"-_", validate=True)
|
|
except (binascii.Error, ValueError) as exc:
|
|
raise ServiceError(
|
|
"ERROR_BAD_IMAGE_BASE64",
|
|
"base64 解码失败",
|
|
status_code=400,
|
|
) from exc
|
|
|
|
|
|
def _validate_client_key(client_key: str | None):
    """Check the caller's clientKey against ``SERVER_CONFIG["client_key"]``.

    Does nothing when no key is configured.

    Raises:
        ServiceError: ``ERROR_KEY_DOES_NOT_EXIST`` (status 401) on a mismatch.
    """
    expected = SERVER_CONFIG.get("client_key")
    if not expected:
        return

    if isinstance(client_key, str) and isinstance(expected, str):
        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. Encoded to bytes because compare_digest rejects
        # non-ASCII str; "surrogatepass" keeps odd input from raising.
        matches = hmac.compare_digest(
            client_key.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )
    else:
        # Missing key or non-string configuration: plain equality, as before.
        matches = client_key == expected

    if not matches:
        raise ServiceError(
            "ERROR_KEY_DOES_NOT_EXIST",
            "clientKey 无效",
            status_code=401,
        )
|
|
|
|
|
|
def _build_task_solution(task: TaskRecord) -> dict:
|
|
if task.result is None:
|
|
return {}
|
|
|
|
if "objects" in task.result:
|
|
objects = list(task.result["objects"])
|
|
primary = objects[0] if objects else None
|
|
solution = {
|
|
"objects": objects,
|
|
"answer": primary,
|
|
"raw": task.result.get("raw", "" if primary is None else str(primary)),
|
|
"timeMs": task.result["time_ms"],
|
|
}
|
|
if task.question:
|
|
solution["question"] = task.question
|
|
if primary is not None:
|
|
solution["text"] = str(primary)
|
|
return solution
|
|
|
|
return {
|
|
"text": task.result["result"],
|
|
"answer": task.result["result"],
|
|
"raw": task.result["raw"],
|
|
"captchaType": task.result["type"],
|
|
"timeMs": task.result["time_ms"],
|
|
}
|
|
|
|
|
|
def _build_task_meta(task: TaskRecord) -> dict:
|
|
payload = {
|
|
"type": task.task_type,
|
|
"captchaType": task.captcha_type,
|
|
}
|
|
if task.question is not None:
|
|
payload["question"] = task.question
|
|
return payload
|
|
|
|
|
|
def _build_task_result_payload(task: TaskRecord) -> dict:
    """Build the getTaskResult response body for a task in any state.

    "processing" and "ready" (any status other than processing/failed) use the
    success envelope; "failed" uses the error envelope with the stored code and
    description, falling back to generic ones.
    """

    def callback_info(delivered: bool) -> dict:
        # Callback delivery summary; identical in every branch except ``delivered``.
        return {
            "configured": bool(task.callback_url),
            "url": task.callback_url,
            "attempts": task.callback_attempts,
            "delivered": delivered,
            "deliveredAt": task.callback_delivered_at,
            "lastError": task.callback_last_error,
        }

    status = task.status

    if status == "processing":
        # A processing task is always reported as not delivered.
        return _task_success_payload(
            taskId=task.task_id,
            status="processing",
            createTime=task.created_at,
            expiresAt=task.expires_at,
            task=_build_task_meta(task),
            callback=callback_info(False),
        )

    was_delivered = task.callback_delivered_at is not None

    if status == "failed":
        return _task_error_payload(
            task.error_code or "ERROR_TASK_FAILED",
            task.error_description or "任务执行失败",
            taskId=task.task_id,
            status="failed",
            ip=task.client_ip,
            createTime=task.created_at,
            endTime=task.completed_at,
            expiresAt=task.expires_at,
            task=_build_task_meta(task),
            callback=callback_info(was_delivered),
        )

    return _task_success_payload(
        taskId=task.task_id,
        status="ready",
        solution=_build_task_solution(task),
        cost=f"{float(SERVER_CONFIG['task_cost']):.5f}",
        ip=task.client_ip,
        createTime=task.created_at,
        endTime=task.completed_at,
        expiresAt=task.expires_at,
        solveCount=1,
        task=_build_task_meta(task),
        callback=callback_info(was_delivered),
    )
|
|
|
|
|
|
def _build_task_callback_payload(task: TaskRecord) -> dict[str, str]:
|
|
payload = {
|
|
"id": task.task_id,
|
|
"taskId": task.task_id,
|
|
"status": task.status,
|
|
"errorId": "0" if task.status == "ready" else "1",
|
|
}
|
|
if task.status == "ready":
|
|
solution = _build_task_solution(task)
|
|
payload["code"] = str(solution.get("answer", ""))
|
|
payload["answer"] = str(solution.get("answer", ""))
|
|
payload["raw"] = str(solution.get("raw", ""))
|
|
payload["timeMs"] = str(solution["timeMs"])
|
|
payload["cost"] = f"{float(SERVER_CONFIG['task_cost']):.5f}"
|
|
if "text" in solution:
|
|
payload["text"] = str(solution["text"])
|
|
if "captchaType" in solution:
|
|
payload["captchaType"] = str(solution["captchaType"])
|
|
if "objects" in solution:
|
|
payload["objects"] = json.dumps(solution["objects"], ensure_ascii=False)
|
|
if "question" in solution:
|
|
payload["question"] = str(solution["question"])
|
|
else:
|
|
payload.update(
|
|
{
|
|
"errorCode": task.error_code or "ERROR_TASK_FAILED",
|
|
"errorDescription": task.error_description or "任务执行失败",
|
|
}
|
|
)
|
|
return payload
|
|
|
|
|
|
def _sign_callback_payload(data: bytes, timestamp: str, secret: str) -> str:
|
|
message = timestamp.encode("utf-8") + b"." + data
|
|
return hmac.new(secret.encode("utf-8"), message, hashlib.sha256).hexdigest()
|
|
|
|
|
|
def _build_callback_request(callback_url: str, payload: dict[str, str]) -> UrlRequest:
    """Build the form-encoded POST request for a task callback.

    When ``SERVER_CONFIG["callback_signing_secret"]`` is set, the request also
    carries a timestamp header, the algorithm name, and an HMAC signature over
    the timestamp and the encoded body.
    """
    body = urlencode(payload).encode("utf-8")
    request_headers = {"Content-Type": "application/x-www-form-urlencoded"}

    signing_secret = SERVER_CONFIG.get("callback_signing_secret")
    if signing_secret:
        issued_at = str(int(time.time()))
        request_headers.update(
            {
                "X-CaptchaBreaker-Timestamp": issued_at,
                "X-CaptchaBreaker-Signature-Alg": "hmac-sha256",
                "X-CaptchaBreaker-Signature": _sign_callback_payload(body, issued_at, signing_secret),
            }
        )

    return UrlRequest(callback_url, data=body, headers=request_headers, method="POST")
|
|
|
|
|
|
def _post_callback(callback_url: str, payload: dict[str, str]):
|
|
request = _build_callback_request(callback_url, payload)
|
|
timeout = SERVER_CONFIG["callback_timeout_seconds"]
|
|
with urlopen(request, timeout=timeout) as response:
|
|
response.read()
|
|
|
|
|
|
|
|
def create_app(pipeline_factory=None, funcaptcha_factories=None):
    """Create the FastAPI application instance (factory function).

    cli.py's cmd_serve depends on this signature.

    Args:
        pipeline_factory: Zero-argument callable returning the general captcha
            pipeline. When omitted, ``CaptchaPipeline`` is built during
            application startup and a missing model file leaves the service
            without a general pipeline. When supplied, it is called
            immediately, before this function returns.
        funcaptcha_factories: Mapping of question name to a zero-argument
            callable returning that question's pipeline. Same timing rules as
            ``pipeline_factory``.

    Returns:
        The configured FastAPI app. Its TaskManager is exposed as
        ``app.state.task_manager``.
    """
    from contextlib import asynccontextmanager
    from typing import Optional

    from fastapi import Body, FastAPI, File, Query, Request, UploadFile
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel

    from inference.fun_captcha import FunCaptchaRollballPipeline
    from inference.pipeline import CaptchaPipeline

    pipeline: Optional[CaptchaPipeline] = None
    load_pipeline_on_startup = pipeline_factory is None
    pipeline_factory = pipeline_factory or CaptchaPipeline
    fun_captcha_pipelines: dict[str, object] = {}
    load_fun_captcha_on_startup = funcaptcha_factories is None
    if funcaptcha_factories is None:
        funcaptcha_factories = {
            "4_3d_rollball_animals": lambda: FunCaptchaRollballPipeline(
                question="4_3d_rollball_animals"
            ),
        }

    def _infer(image_bytes: bytes, captcha_type: str | None, question: str | None) -> dict:
        """Run one recognition and return the normalized result dict.

        A non-empty ``question`` selects the matching special pipeline and
        ``captcha_type`` is then ignored; otherwise the general pipeline is
        used. Every failure is reported as a ServiceError.
        """
        normalized_question = _normalize_question(question)
        normalized_type = None if normalized_question is not None else _normalize_captcha_type(captcha_type)

        if pipeline is None and normalized_question is None:
            raise ServiceError(
                "ERROR_NO_MODELS_LOADED",
                "模型未加载,请先训练并导出 ONNX 模型",
                status_code=503,
            )
        if not image_bytes:
            raise ServiceError("ERROR_EMPTY_IMAGE", "空图片", status_code=400)

        try:
            if normalized_question is not None:
                fun_pipeline = fun_captcha_pipelines.get(normalized_question)
                if fun_pipeline is None:
                    raise ServiceError(
                        "ERROR_NO_MODELS_LOADED",
                        f"专项模型未加载: {normalized_question}",
                        status_code=503,
                    )
                result = fun_pipeline.solve(image_bytes)
            else:
                result = pipeline.solve(image_bytes, captcha_type=normalized_type)
        except ServiceError:
            raise
        except (RuntimeError, TypeError) as exc:
            raise ServiceError(
                "ERROR_INFERENCE_FAILED",
                str(exc),
                status_code=400,
            ) from exc
        except Exception as exc:
            # Anything else from the pipeline is reported as an undecodable image.
            raise ServiceError(
                "ERROR_IMAGE_DECODE_FAILED",
                f"图片解析失败: {exc}",
                status_code=400,
            ) from exc

        payload = {
            "type": result["type"],
            "result": result["result"],
            "raw": result["raw"],
            "time_ms": result["time_ms"],
        }
        # Optional fields are passed through only when the pipeline produced them.
        for optional_key in ("question", "objects", "scores"):
            if optional_key in result:
                payload[optional_key] = result[optional_key]
        return payload

    task_manager = TaskManager(
        solve_fn=_infer,
        ttl_seconds=SERVER_CONFIG["task_ttl_seconds"],
        max_workers=SERVER_CONFIG["task_workers"],
        tasks_dir=SERVER_CONFIG["tasks_dir"],
    )

    @asynccontextmanager
    async def lifespan(_app: FastAPI):
        # Startup: build the default pipelines. A missing model file is not
        # fatal; the service starts and reports the missing models on request.
        nonlocal pipeline
        if load_pipeline_on_startup:
            try:
                pipeline = pipeline_factory()
            except FileNotFoundError as exc:
                logger.warning("captcha pipeline not loaded: %s", exc)
                pipeline = None
        if load_fun_captcha_on_startup:
            for question, factory in funcaptcha_factories.items():
                try:
                    fun_captcha_pipelines[question] = factory()
                except FileNotFoundError as exc:
                    logger.warning(
                        "funcaptcha pipeline not loaded: question=%s err=%s",
                        question,
                        exc,
                    )
                    continue

        try:
            yield
        finally:
            # Shutdown: let queued and running tasks finish.
            task_manager.shutdown()

    app = FastAPI(
        title="CaptchaBreaker",
        description="验证码识别多模型系统 - HTTP 推理服务",
        version="0.2.0",
        lifespan=lifespan,
    )

    app.state.task_manager = task_manager

    # Caller-supplied factories are invoked right away instead of at startup.
    if not load_pipeline_on_startup:
        pipeline = pipeline_factory()
    if not load_fun_captcha_on_startup:
        fun_captcha_pipelines = {
            question: factory()
            for question, factory in funcaptcha_factories.items()
        }

    # ---- Request bodies ----
    # NOTE(review): the module uses `from __future__ import annotations`, so the
    # handler annotations below are stored as strings naming classes that are
    # local to create_app. Confirm the installed FastAPI version resolves them.
    class SolveRequest(BaseModel):
        image: str
        type: Optional[str] = None
        question: Optional[str] = None

    class AsyncTaskRequest(BaseModel):
        type: str
        body: Optional[str] = None
        image: Optional[str] = None
        captchaType: Optional[str] = None
        question: Optional[str] = None

    class CreateTaskRequest(BaseModel):
        clientKey: Optional[str] = None
        callbackUrl: Optional[str] = None
        softId: Optional[int] = None
        languagePool: Optional[str] = None
        task: AsyncTaskRequest

    class GetTaskResultRequest(BaseModel):
        clientKey: Optional[str] = None
        taskId: str

    class BalanceRequest(BaseModel):
        clientKey: Optional[str] = None

    def _sync_error_response(exc: ServiceError):
        # Error shape shared by the synchronous /solve endpoints.
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": exc.description, "error_code": exc.code},
        )

    # ---- Routes ----
    @app.get("/health")
    @app.get("/api/v1/health")
    def health():
        models_loaded = pipeline is not None or bool(fun_captcha_pipelines)
        return {
            "status": "ok" if models_loaded else "no_models",
            "models_loaded": models_loaded,
            "async_tasks": task_manager.stats(),
            "client_key_required": bool(SERVER_CONFIG.get("client_key")),
            "supported_task_types": sorted(ASYNC_TASK_TYPES),
            "supported_task_questions": sorted(FUN_CAPTCHA_TASKS),
        }

    @app.post("/solve")
    @app.post("/api/v1/solve")
    async def solve_base64(req: SolveRequest):
        """JSON 请求,图片通过 base64 传输。"""
        # JSON request; the image is sent base64-encoded. (The docstring is left
        # as written because it is the route's runtime docstring.)
        try:
            image_bytes = _decode_image_b64(req.image)
            return _infer(image_bytes, req.type, getattr(req, "question", None))
        except ServiceError as exc:
            return _sync_error_response(exc)

    @app.post("/solve/upload")
    @app.post("/api/v1/solve/upload")
    async def solve_upload(
        image: UploadFile = File(...),
        type: Optional[str] = Query(None, description="指定类型跳过分类"),
        question: Optional[str] = Query(None, description="专项 question"),
    ):
        """multipart 文件上传。"""
        # Multipart file upload; ``type`` and ``question`` are query parameters.
        try:
            return _infer(await image.read(), type, question)
        except ServiceError as exc:
            return _sync_error_response(exc)

    @app.post("/createTask")
    @app.post("/api/v1/createTask")
    async def create_task(req: CreateTaskRequest, request: Request):
        """创建异步识别任务,返回 taskId。"""
        # Creates an asynchronous task and returns its taskId. Errors are
        # returned in the response body as an error payload, not raised.
        task_req = req.task
        try:
            _validate_client_key(req.clientKey)
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)

        if task_req.type not in ASYNC_TASK_TYPES:
            return _task_error_payload(
                "ERROR_TASK_TYPE_NOT_SUPPORTED",
                f"不支持的任务类型: {task_req.type}",
            )

        # Validate everything up front so bad requests fail here rather than
        # inside the background task.
        try:
            question = _normalize_question(getattr(task_req, "question", None))
            if question is None:
                _normalize_captcha_type(task_req.captchaType)
            image_bytes = _decode_image_b64(task_req.body or task_req.image or "")
            if question is None and pipeline is None:
                raise ServiceError(
                    "ERROR_NO_MODELS_LOADED",
                    "模型未加载,请先训练并导出 ONNX 模型",
                    status_code=503,
                )
            if question is not None and question not in fun_captcha_pipelines:
                raise ServiceError(
                    "ERROR_NO_MODELS_LOADED",
                    f"专项模型未加载: {question}",
                    status_code=503,
                )
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)

        client = getattr(request, "client", None)
        client_ip = getattr(client, "host", None)
        task_id = task_manager.create_task(
            image_bytes,
            task_req.captchaType,
            question=question,
            client_ip=client_ip,
            task_type=task_req.type,
            callback_url=getattr(req, "callbackUrl", None),
        )
        # Read the record back for its timestamps; it can be None if it has
        # already expired.
        record = task_manager.get_task(task_id)
        return _task_success_payload(
            taskId=task_id,
            status="processing",
            createTime=record.created_at if record else int(time.time()),
            expiresAt=record.expires_at if record else None,
        )

    @app.post("/getTaskResult")
    @app.post("/api/v1/getTaskResult")
    async def get_task_result(req: GetTaskResultRequest):
        """按 taskId 查询任务状态与结果。"""
        # Looks up a task's status and result by taskId.
        try:
            _validate_client_key(req.clientKey)
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description, taskId=req.taskId)

        record = task_manager.get_task(str(req.taskId))
        if record is None:
            return _task_error_payload(
                "ERROR_TASK_NOT_FOUND",
                f"任务不存在或已过期: {req.taskId}",
                taskId=req.taskId,
            )

        return _build_task_result_payload(record)

    @app.post("/getBalance")
    @app.post("/api/v1/getBalance")
    async def get_balance(req: BalanceRequest | None = Body(default=None)):
        # The body is optional; a missing body is treated as "no clientKey".
        try:
            _validate_client_key(getattr(req, "clientKey", None))
        except ServiceError as exc:
            return _task_error_payload(exc.code, exc.description)
        return _task_success_payload(balance=SERVER_CONFIG["balance"])

    return app
|