Files
CaptchBreaker/inference/pipeline.py
2026-03-10 18:47:29 +08:00

232 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
核心推理流水线
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
推理流程:
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
对算式类型,解码后还会调用 math_eval 计算结果。
"""
import io
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import (
CAPTCHA_TYPES,
IMAGE_SIZE,
INFERENCE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
)
from inference.math_eval import eval_captcha_math
def _try_import_ort():
    """Import onnxruntime lazily and return the module.

    The import happens at call time rather than at module import time, so a
    missing onnxruntime installation surfaces as an ImportError that carries
    an installation hint.

    Returns:
        The imported ``onnxruntime`` module.

    Raises:
        ImportError: onnxruntime could not be imported. The underlying import
            failure is attached as ``__cause__``.
    """
    try:
        import onnxruntime as ort
    except ImportError as err:
        # Chain explicitly so the real import failure (e.g. a broken native
        # library) stays visible in the traceback as the direct cause.
        raise ImportError(
            "推理需要 onnxruntime请安装: uv pip install onnxruntime"
        ) from err
    return ort
class CaptchaPipeline:
"""
核心推理流水线。
加载调度模型和所有专家模型 (ONNX 格式)。
提供统一的 solve(image) 接口。
"""
def __init__(self, models_dir: str | None = None):
"""
初始化加载所有 ONNX 模型。
Args:
models_dir: ONNX 模型目录,默认使用 config 中的路径
"""
ort = _try_import_ort()
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
self.mean = INFERENCE_CONFIG["normalize_mean"]
self.std = INFERENCE_CONFIG["normalize_std"]
# 字符集映射
self._chars = {
"normal": NORMAL_CHARS,
"math": MATH_CHARS,
"3d": THREED_CHARS,
}
# 专家模型名 → ONNX 文件名
self._model_files = {
"classifier": "classifier.onnx",
"normal": "normal.onnx",
"math": "math.onnx",
"3d": "threed.onnx",
}
# 加载所有可用模型
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 2
self._sessions: dict[str, "ort.InferenceSession"] = {}
for name, fname in self._model_files.items():
path = self.models_dir / fname
if path.exists():
self._sessions[name] = ort.InferenceSession(
str(path), sess_options=opts,
providers=["CPUExecutionProvider"],
)
loaded = list(self._sessions.keys())
if not loaded:
raise FileNotFoundError(
f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
)
# ----------------------------------------------------------
# 公共接口
# ----------------------------------------------------------
def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
"""
图片预处理: resize, grayscale, normalize, 转 numpy。
Args:
image: PIL Image
target_size: (H, W)
Returns:
(1, 1, H, W) float32 ndarray
"""
h, w = target_size
img = image.convert("L").resize((w, h), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0
arr = (arr - self.mean) / self.std
return arr.reshape(1, 1, h, w)
def classify(self, image: Image.Image) -> str:
"""
调度分类,返回类型名: 'normal' / 'math' / '3d'
Raises:
RuntimeError: 分类器模型未加载
"""
if "classifier" not in self._sessions:
raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")
inp = self.preprocess(image, IMAGE_SIZE["classifier"])
session = self._sessions["classifier"]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
idx = int(np.argmax(logits, axis=1)[0])
return CAPTCHA_TYPES[idx]
def solve(
self,
image,
captcha_type: str | None = None,
) -> dict:
"""
完整识别流程。
Args:
image: PIL.Image 或文件路径 (str/Path) 或 bytes
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
Returns:
dict: {
"type": str, # 验证码类型
"raw": str, # OCR 原始识别结果
"result": str, # 最终答案 (算式型为计算结果)
"time_ms": float, # 推理耗时 (毫秒)
}
"""
t0 = time.perf_counter()
# 1. 解析输入
img = self._load_image(image)
# 2. 分类
if captcha_type is None:
captcha_type = self.classify(img)
# 3. 路由到专家模型
if captcha_type not in self._sessions:
raise RuntimeError(
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
)
size_key = captcha_type # "normal"/"math"/"3d"
inp = self.preprocess(img, IMAGE_SIZE[size_key])
session = self._sessions[captcha_type]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
# 4. CTC 贪心解码
chars = self._chars[captcha_type]
raw_text = self._ctc_greedy_decode(logits, chars)
# 5. 后处理
if captcha_type == "math":
try:
result = eval_captcha_math(raw_text)
except ValueError:
result = raw_text # 解析失败则返回原始文本
else:
result = raw_text
elapsed = (time.perf_counter() - t0) * 1000
return {
"type": captcha_type,
"raw": raw_text,
"result": result,
"time_ms": round(elapsed, 2),
}
# ----------------------------------------------------------
# 私有方法
# ----------------------------------------------------------
@staticmethod
def _load_image(image) -> Image.Image:
"""将多种输入类型统一转为 PIL Image。"""
if isinstance(image, Image.Image):
return image
if isinstance(image, (str, Path)):
return Image.open(image).convert("RGB")
if isinstance(image, bytes):
return Image.open(io.BytesIO(image)).convert("RGB")
raise TypeError(f"不支持的图片输入类型: {type(image)}")
@staticmethod
def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
"""
CTC 贪心解码 (numpy 版本)。
Args:
logits: (T, B, C) ONNX 输出
chars: 字符集 (不含 blank, blank=index 0)
Returns:
解码后的字符串
"""
# 取 batch=0
preds = np.argmax(logits[:, 0, :], axis=1) # (T,)
decoded = []
prev = -1
for idx in preds:
if idx != 0 and idx != prev:
decoded.append(chars[idx - 1])
prev = idx
return "".join(decoded)