Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

231
inference/pipeline.py Normal file
View File

@@ -0,0 +1,231 @@
"""
核心推理流水线
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
推理流程:
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
对算式类型,解码后还会调用 math_eval 计算结果。
"""
import io
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import (
CAPTCHA_TYPES,
IMAGE_SIZE,
INFERENCE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
)
from inference.math_eval import eval_captcha_math
def _try_import_ort():
"""延迟导入 onnxruntime给出友好错误提示。"""
try:
import onnxruntime as ort
return ort
except ImportError:
raise ImportError(
"推理需要 onnxruntime请安装: uv pip install onnxruntime"
)
class CaptchaPipeline:
"""
核心推理流水线。
加载调度模型和所有专家模型 (ONNX 格式)。
提供统一的 solve(image) 接口。
"""
def __init__(self, models_dir: str | None = None):
"""
初始化加载所有 ONNX 模型。
Args:
models_dir: ONNX 模型目录,默认使用 config 中的路径
"""
ort = _try_import_ort()
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
self.mean = INFERENCE_CONFIG["normalize_mean"]
self.std = INFERENCE_CONFIG["normalize_std"]
# 字符集映射
self._chars = {
"normal": NORMAL_CHARS,
"math": MATH_CHARS,
"3d": THREED_CHARS,
}
# 专家模型名 → ONNX 文件名
self._model_files = {
"classifier": "classifier.onnx",
"normal": "normal.onnx",
"math": "math.onnx",
"3d": "threed.onnx",
}
# 加载所有可用模型
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 2
self._sessions: dict[str, "ort.InferenceSession"] = {}
for name, fname in self._model_files.items():
path = self.models_dir / fname
if path.exists():
self._sessions[name] = ort.InferenceSession(
str(path), sess_options=opts,
providers=["CPUExecutionProvider"],
)
loaded = list(self._sessions.keys())
if not loaded:
raise FileNotFoundError(
f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
)
# ----------------------------------------------------------
# 公共接口
# ----------------------------------------------------------
def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
"""
图片预处理: resize, grayscale, normalize, 转 numpy。
Args:
image: PIL Image
target_size: (H, W)
Returns:
(1, 1, H, W) float32 ndarray
"""
h, w = target_size
img = image.convert("L").resize((w, h), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0
arr = (arr - self.mean) / self.std
return arr.reshape(1, 1, h, w)
def classify(self, image: Image.Image) -> str:
"""
调度分类,返回类型名: 'normal' / 'math' / '3d'
Raises:
RuntimeError: 分类器模型未加载
"""
if "classifier" not in self._sessions:
raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")
inp = self.preprocess(image, IMAGE_SIZE["classifier"])
session = self._sessions["classifier"]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
idx = int(np.argmax(logits, axis=1)[0])
return CAPTCHA_TYPES[idx]
def solve(
self,
image,
captcha_type: str | None = None,
) -> dict:
"""
完整识别流程。
Args:
image: PIL.Image 或文件路径 (str/Path) 或 bytes
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
Returns:
dict: {
"type": str, # 验证码类型
"raw": str, # OCR 原始识别结果
"result": str, # 最终答案 (算式型为计算结果)
"time_ms": float, # 推理耗时 (毫秒)
}
"""
t0 = time.perf_counter()
# 1. 解析输入
img = self._load_image(image)
# 2. 分类
if captcha_type is None:
captcha_type = self.classify(img)
# 3. 路由到专家模型
if captcha_type not in self._sessions:
raise RuntimeError(
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
)
size_key = captcha_type # "normal"/"math"/"3d"
inp = self.preprocess(img, IMAGE_SIZE[size_key])
session = self._sessions[captcha_type]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
# 4. CTC 贪心解码
chars = self._chars[captcha_type]
raw_text = self._ctc_greedy_decode(logits, chars)
# 5. 后处理
if captcha_type == "math":
try:
result = eval_captcha_math(raw_text)
except ValueError:
result = raw_text # 解析失败则返回原始文本
else:
result = raw_text
elapsed = (time.perf_counter() - t0) * 1000
return {
"type": captcha_type,
"raw": raw_text,
"result": result,
"time_ms": round(elapsed, 2),
}
# ----------------------------------------------------------
# 私有方法
# ----------------------------------------------------------
@staticmethod
def _load_image(image) -> Image.Image:
"""将多种输入类型统一转为 PIL Image。"""
if isinstance(image, Image.Image):
return image
if isinstance(image, (str, Path)):
return Image.open(image).convert("RGB")
if isinstance(image, bytes):
return Image.open(io.BytesIO(image)).convert("RGB")
raise TypeError(f"不支持的图片输入类型: {type(image)}")
@staticmethod
def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
"""
CTC 贪心解码 (numpy 版本)。
Args:
logits: (T, B, C) ONNX 输出
chars: 字符集 (不含 blank, blank=index 0)
Returns:
解码后的字符串
"""
# 取 batch=0
preds = np.argmax(logits[:, 0, :], axis=1) # (T,)
decoded = []
prev = -1
for idx in preds:
if idx != 0 and idx != prev:
decoded.append(chars[idx - 1])
prev = idx
return "".join(decoded)