232 lines
7.0 KiB
Python
232 lines
7.0 KiB
Python
"""
|
||
核心推理流水线
|
||
|
||
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
|
||
|
||
推理流程:
|
||
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
|
||
|
||
对算式类型,解码后还会调用 math_eval 计算结果。
|
||
"""
|
||
|
||
import io
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
from config import (
|
||
CAPTCHA_TYPES,
|
||
IMAGE_SIZE,
|
||
INFERENCE_CONFIG,
|
||
NORMAL_CHARS,
|
||
MATH_CHARS,
|
||
THREED_CHARS,
|
||
)
|
||
from inference.math_eval import eval_captcha_math
|
||
|
||
|
||
def _try_import_ort():
|
||
"""延迟导入 onnxruntime,给出友好错误提示。"""
|
||
try:
|
||
import onnxruntime as ort
|
||
return ort
|
||
except ImportError:
|
||
raise ImportError(
|
||
"推理需要 onnxruntime,请安装: uv pip install onnxruntime"
|
||
)
|
||
|
||
|
||
class CaptchaPipeline:
|
||
"""
|
||
核心推理流水线。
|
||
|
||
加载调度模型和所有专家模型 (ONNX 格式)。
|
||
提供统一的 solve(image) 接口。
|
||
"""
|
||
|
||
def __init__(self, models_dir: str | None = None):
|
||
"""
|
||
初始化加载所有 ONNX 模型。
|
||
|
||
Args:
|
||
models_dir: ONNX 模型目录,默认使用 config 中的路径
|
||
"""
|
||
ort = _try_import_ort()
|
||
|
||
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
|
||
self.mean = INFERENCE_CONFIG["normalize_mean"]
|
||
self.std = INFERENCE_CONFIG["normalize_std"]
|
||
|
||
# 字符集映射
|
||
self._chars = {
|
||
"normal": NORMAL_CHARS,
|
||
"math": MATH_CHARS,
|
||
"3d": THREED_CHARS,
|
||
}
|
||
|
||
# 专家模型名 → ONNX 文件名
|
||
self._model_files = {
|
||
"classifier": "classifier.onnx",
|
||
"normal": "normal.onnx",
|
||
"math": "math.onnx",
|
||
"3d": "threed.onnx",
|
||
}
|
||
|
||
# 加载所有可用模型
|
||
opts = ort.SessionOptions()
|
||
opts.inter_op_num_threads = 1
|
||
opts.intra_op_num_threads = 2
|
||
|
||
self._sessions: dict[str, "ort.InferenceSession"] = {}
|
||
for name, fname in self._model_files.items():
|
||
path = self.models_dir / fname
|
||
if path.exists():
|
||
self._sessions[name] = ort.InferenceSession(
|
||
str(path), sess_options=opts,
|
||
providers=["CPUExecutionProvider"],
|
||
)
|
||
|
||
loaded = list(self._sessions.keys())
|
||
if not loaded:
|
||
raise FileNotFoundError(
|
||
f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
|
||
)
|
||
|
||
# ----------------------------------------------------------
|
||
# 公共接口
|
||
# ----------------------------------------------------------
|
||
def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
|
||
"""
|
||
图片预处理: resize, grayscale, normalize, 转 numpy。
|
||
|
||
Args:
|
||
image: PIL Image
|
||
target_size: (H, W)
|
||
|
||
Returns:
|
||
(1, 1, H, W) float32 ndarray
|
||
"""
|
||
h, w = target_size
|
||
img = image.convert("L").resize((w, h), Image.BILINEAR)
|
||
arr = np.array(img, dtype=np.float32) / 255.0
|
||
arr = (arr - self.mean) / self.std
|
||
return arr.reshape(1, 1, h, w)
|
||
|
||
def classify(self, image: Image.Image) -> str:
|
||
"""
|
||
调度分类,返回类型名: 'normal' / 'math' / '3d'。
|
||
|
||
Raises:
|
||
RuntimeError: 分类器模型未加载
|
||
"""
|
||
if "classifier" not in self._sessions:
|
||
raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")
|
||
|
||
inp = self.preprocess(image, IMAGE_SIZE["classifier"])
|
||
session = self._sessions["classifier"]
|
||
input_name = session.get_inputs()[0].name
|
||
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
|
||
idx = int(np.argmax(logits, axis=1)[0])
|
||
return CAPTCHA_TYPES[idx]
|
||
|
||
def solve(
|
||
self,
|
||
image,
|
||
captcha_type: str | None = None,
|
||
) -> dict:
|
||
"""
|
||
完整识别流程。
|
||
|
||
Args:
|
||
image: PIL.Image 或文件路径 (str/Path) 或 bytes
|
||
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
|
||
|
||
Returns:
|
||
dict: {
|
||
"type": str, # 验证码类型
|
||
"raw": str, # OCR 原始识别结果
|
||
"result": str, # 最终答案 (算式型为计算结果)
|
||
"time_ms": float, # 推理耗时 (毫秒)
|
||
}
|
||
"""
|
||
t0 = time.perf_counter()
|
||
|
||
# 1. 解析输入
|
||
img = self._load_image(image)
|
||
|
||
# 2. 分类
|
||
if captcha_type is None:
|
||
captcha_type = self.classify(img)
|
||
|
||
# 3. 路由到专家模型
|
||
if captcha_type not in self._sessions:
|
||
raise RuntimeError(
|
||
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
|
||
)
|
||
|
||
size_key = captcha_type # "normal"/"math"/"3d"
|
||
inp = self.preprocess(img, IMAGE_SIZE[size_key])
|
||
session = self._sessions[captcha_type]
|
||
input_name = session.get_inputs()[0].name
|
||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||
|
||
# 4. CTC 贪心解码
|
||
chars = self._chars[captcha_type]
|
||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||
|
||
# 5. 后处理
|
||
if captcha_type == "math":
|
||
try:
|
||
result = eval_captcha_math(raw_text)
|
||
except ValueError:
|
||
result = raw_text # 解析失败则返回原始文本
|
||
else:
|
||
result = raw_text
|
||
|
||
elapsed = (time.perf_counter() - t0) * 1000
|
||
|
||
return {
|
||
"type": captcha_type,
|
||
"raw": raw_text,
|
||
"result": result,
|
||
"time_ms": round(elapsed, 2),
|
||
}
|
||
|
||
# ----------------------------------------------------------
|
||
# 私有方法
|
||
# ----------------------------------------------------------
|
||
@staticmethod
|
||
def _load_image(image) -> Image.Image:
|
||
"""将多种输入类型统一转为 PIL Image。"""
|
||
if isinstance(image, Image.Image):
|
||
return image
|
||
if isinstance(image, (str, Path)):
|
||
return Image.open(image).convert("RGB")
|
||
if isinstance(image, bytes):
|
||
return Image.open(io.BytesIO(image)).convert("RGB")
|
||
raise TypeError(f"不支持的图片输入类型: {type(image)}")
|
||
|
||
@staticmethod
|
||
def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
|
||
"""
|
||
CTC 贪心解码 (numpy 版本)。
|
||
|
||
Args:
|
||
logits: (T, B, C) ONNX 输出
|
||
chars: 字符集 (不含 blank, blank=index 0)
|
||
|
||
Returns:
|
||
解码后的字符串
|
||
"""
|
||
# 取 batch=0
|
||
preds = np.argmax(logits[:, 0, :], axis=1) # (T,)
|
||
decoded = []
|
||
prev = -1
|
||
for idx in preds:
|
||
if idx != 0 and idx != prev:
|
||
decoded.append(chars[idx - 1])
|
||
prev = idx
|
||
return "".join(decoded)
|