Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

18
inference/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
"""
推理包
- pipeline.py: CaptchaPipeline 核心推理流水线
- export_onnx.py: PyTorch → ONNX 导出
- math_eval.py: 算式计算模块
"""
from inference.pipeline import CaptchaPipeline
from inference.math_eval import eval_captcha_math
from inference.export_onnx import export_model, export_all
# Names re-exported as the public API of the inference package; each one is
# imported from its submodule directly above.
__all__ = [
"CaptchaPipeline",
"eval_captcha_math",
"export_model",
"export_all",
]

121
inference/export_onnx.py Normal file
View File

@@ -0,0 +1,121 @@
"""
ONNX 导出脚本
从 checkpoints/ 加载训练好的 PyTorch 模型,导出为 ONNX 格式到 onnx_models/。
支持逐个导出或一次导出全部。
"""
import torch
import torch.nn as nn
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
IMAGE_SIZE,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
NUM_CAPTCHA_TYPES,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
def export_model(
    model: nn.Module,
    model_name: str,
    input_shape: tuple,
    onnx_dir: str | None = None,
):
    """
    Export a single PyTorch model to an ONNX file named ``<model_name>.onnx``.

    The model is switched to eval mode and moved to CPU in place before it is
    traced with a random dummy input of batch size 1.

    Args:
        model: PyTorch model with its weights already loaded.
        model_name: Model name (classifier / normal / math / threed); also
            used as the output file stem.
        input_shape: Shape of one input sample without the batch axis,
            i.e. (C, H, W).
        onnx_dir: Output directory; falls back to ``config.ONNX_DIR`` when
            empty or None.
    """
    from pathlib import Path

    target_dir = ONNX_DIR if not onnx_dir else Path(onnx_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / f"{model_name}.onnx"

    model.eval()
    model.cpu()
    sample_input = torch.randn(1, *input_shape)

    # The classifier emits batch-first output, whereas the recognizers are
    # treated as CTC models whose output is (T, B, C) per the original note,
    # so their batch dimension sits on axis 1.
    output_batch_axis = 0 if model_name == "classifier" else 1
    axes_spec = {
        "input": {0: "batch"},
        "output": {output_batch_axis: "batch"},
    }

    opset = ONNX_CONFIG["opset_version"]
    if not ONNX_CONFIG["dynamic_batch"]:
        # Fixed batch size requested: export without dynamic axes.
        axes_spec = None

    torch.onnx.export(
        model,
        sample_input,
        str(destination),
        opset_version=opset,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes_spec,
    )

    size_kb = destination.stat().st_size / 1024
    print(f"[ONNX] 导出完成: {destination} ({size_kb:.1f} KB)")
def _load_and_export(model_name: str):
    """
    Load ``<model_name>.pth`` from the checkpoints directory and export it to ONNX.

    Prints a skip notice and returns when the checkpoint file is missing, and
    prints an error and returns when ``model_name`` is not one of
    classifier / normal / math / threed.

    Args:
        model_name: Checkpoint stem and name of the model to rebuild.
    """
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    if not ckpt_path.exists():
        print(f"[跳过] {model_name}: checkpoint 不存在 ({ckpt_path})")
        return

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")

    # Recognizer name -> (model class, IMAGE_SIZE key, fallback charset).
    # Note that the "threed" model reads its size from the "3d" key.
    recognizers = {
        "normal": (LiteCRNN, "normal", NORMAL_CHARS),
        "math": (LiteCRNN, "math", MATH_CHARS),
        "threed": (ThreeDCNN, "3d", THREED_CHARS),
    }

    if model_name == "classifier":
        model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
        height, width = IMAGE_SIZE["classifier"]
    elif model_name in recognizers:
        model_cls, size_key, fallback_chars = recognizers[model_name]
        # Prefer the charset stored in the checkpoint over the config default.
        charset = ckpt.get("chars", fallback_chars)
        height, width = IMAGE_SIZE[size_key]
        model = model_cls(chars=charset, img_h=height, img_w=width)
    else:
        print(f"[错误] 未知模型: {model_name}")
        return

    model.load_state_dict(ckpt["model_state_dict"])
    # Every model is exported with a single input channel.
    export_model(model, model_name, (1, height, width))
def export_all():
    """Export the classifier, normal, math and threed models, in that order."""
    banner = "=" * 50
    for line in (banner, "导出全部 ONNX 模型", banner):
        print(line)

    for model_name in ("classifier", "normal", "math", "threed"):
        _load_and_export(model_name)

    print("\n全部导出完成。")


# Running this file as a script exports every model.
if __name__ == "__main__":
    export_all()

66
inference/math_eval.py Normal file
View File

@@ -0,0 +1,66 @@
"""
算式计算模块
解析并计算验证码中的算式表达式。
用正则提取数字和运算符,不使用 eval()。
支持: 加减乘除,个位到两位数运算。
"""
import re
# Matches: number, operator, number (optionally followed by "=?" or similar,
# which is simply left unmatched).
_EXPR_PATTERN = re.compile(
    r"(\d+)\s*([+\-×÷xX*])\s*(\d+)"
)

# Maps every accepted operator spelling onto one canonical symbol.
_OP_MAP = {
    "+": "+",
    "-": "-",
    "×": "×",
    "÷": "÷",
    "x": "×",
    "X": "×",
    "*": "×",
}


def eval_captcha_math(expr: str) -> str:
    """
    Parse and evaluate the arithmetic expression of a math captcha.

    The first ``<number> <operator> <number>`` occurrence in ``expr`` is
    extracted with a regular expression (``eval()`` is never used) and
    computed. Addition, subtraction, multiplication and integer (floor)
    division are supported.

    Examples:
        "3+8=?" -> "11", "12×3=?" -> "36", "15-7=?" -> "8", "3+8" -> "11"

    Args:
        expr: Recognized captcha text.

    Returns:
        The result as a decimal string.

    Raises:
        ValueError: If no expression can be found, the divisor is zero, or
            the operator is not supported.
    """
    found = _EXPR_PATTERN.search(expr)
    if found is None:
        raise ValueError(f"无法解析算式: {expr!r}")

    left_text, symbol, right_text = found.groups()
    left = int(left_text)
    right = int(right_text)
    op = _OP_MAP.get(symbol, symbol)

    if op == "÷":
        if right == 0:
            raise ValueError(f"除数为零: {expr!r}")
        # Floor division: the answer is always reported as an integer.
        return str(left // right)
    if op == "×":
        return str(left * right)
    if op == "-":
        return str(left - right)
    if op == "+":
        return str(left + right)

    raise ValueError(f"不支持的运算符: {op!r} 原式: {expr!r}")

231
inference/pipeline.py Normal file
View File

@@ -0,0 +1,231 @@
"""
核心推理流水线
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
推理流程:
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
对算式类型,解码后还会调用 math_eval 计算结果。
"""
import io
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import (
CAPTCHA_TYPES,
IMAGE_SIZE,
INFERENCE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
)
from inference.math_eval import eval_captcha_math
def _try_import_ort():
    """
    Import onnxruntime lazily and return the module.

    Returns:
        The imported ``onnxruntime`` module.

    Raises:
        ImportError: If onnxruntime cannot be imported. The message carries an
            installation hint and the original failure is attached as the
            exception's ``__cause__``.
    """
    try:
        import onnxruntime as ort
    except ImportError as exc:
        # Chain explicitly so the underlying import failure stays visible as
        # the direct cause instead of an incidental "during handling" context.
        raise ImportError(
            "推理需要 onnxruntime请安装: uv pip install onnxruntime"
        ) from exc
    return ort
class CaptchaPipeline:
    """
    Core inference pipeline.

    Loads the dispatch classifier and every available expert model (ONNX
    format) and exposes a single ``solve(image)`` entry point.
    """

    def __init__(self, models_dir: str | None = None):
        """
        Load every ONNX model that exists in the models directory.

        Args:
            models_dir: Directory holding the ONNX files; defaults to
                ``INFERENCE_CONFIG["default_models_dir"]``.

        Raises:
            ImportError: If onnxruntime is not installed.
            FileNotFoundError: If none of the expected model files exists.
        """
        ort = _try_import_ort()
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.mean = INFERENCE_CONFIG["normalize_mean"]
        self.std = INFERENCE_CONFIG["normalize_std"]

        # Captcha type -> character set used when decoding that expert's output.
        self._chars = {
            "normal": NORMAL_CHARS,
            "math": MATH_CHARS,
            "3d": THREED_CHARS,
        }

        # Session key -> ONNX file name (the "3d" expert lives in threed.onnx).
        self._model_files = {
            "classifier": "classifier.onnx",
            "normal": "normal.onnx",
            "math": "math.onnx",
            "3d": "threed.onnx",
        }

        # One shared set of session options for all models.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2

        # Missing files are skipped here; the lookup sites report them later.
        self._sessions: dict[str, "ort.InferenceSession"] = {}
        for name, fname in self._model_files.items():
            path = self.models_dir / fname
            if path.exists():
                self._sessions[name] = ort.InferenceSession(
                    str(path),
                    sess_options=opts,
                    providers=["CPUExecutionProvider"],
                )

        if not self._sessions:
            raise FileNotFoundError(
                f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
            )

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """
        Convert an image to a normalized grayscale network input.

        Args:
            image: PIL image.
            target_size: Target size as (H, W).

        Returns:
            float32 ndarray of shape (1, 1, H, W).
        """
        h, w = target_size
        # PIL's resize takes (width, height), hence the swapped order.
        img = image.convert("L").resize((w, h), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - self.mean) / self.std
        # The configured mean/std may be NumPy float64 values or sequences,
        # which would promote the result to float64; force float32 so the
        # documented dtype always holds.
        return arr.astype(np.float32, copy=False).reshape(1, 1, h, w)

    def classify(self, image: Image.Image) -> str:
        """
        Run the dispatch classifier and return the predicted captcha type.

        Args:
            image: PIL image.

        Returns:
            The entry of ``CAPTCHA_TYPES`` selected by the highest logit
            ('normal' / 'math' / '3d' per the original documentation).

        Raises:
            RuntimeError: If the classifier model is not loaded.
        """
        if "classifier" not in self._sessions:
            raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")

        inp = self.preprocess(image, IMAGE_SIZE["classifier"])
        session = self._sessions["classifier"]
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inp})[0]  # (1, num_types)
        idx = int(np.argmax(logits, axis=1)[0])
        return CAPTCHA_TYPES[idx]

    def solve(
        self,
        image,
        captcha_type: str | None = None,
    ) -> dict:
        """
        Run the full recognition flow on one image.

        Args:
            image: PIL.Image, a file path (str/Path), or raw image bytes.
            captcha_type: Captcha type ('normal' / 'math' / '3d'); when given,
                the classification step is skipped.

        Returns:
            dict with the keys:
                "type":    captcha type that was used,
                "raw":     raw decoded text,
                "result":  final answer (the computed value for math
                           captchas, falling back to the raw text when the
                           expression cannot be parsed),
                "time_ms": elapsed time in milliseconds, rounded to 2 places.

        Raises:
            TypeError: If ``image`` is of an unsupported type.
            RuntimeError: If the classifier is needed but not loaded, or if
                ``captcha_type`` does not name a loaded expert model.
        """
        t0 = time.perf_counter()

        # 1. Normalize the input to a PIL image.
        img = self._load_image(image)

        # 2. Classify unless the caller already named the type.
        if captcha_type is None:
            captcha_type = self.classify(img)

        # 3. Route to the expert model. The type must be an expert with a
        #    character set and not merely any loaded session: "classifier"
        #    has a session but cannot be decoded, and previously failed with
        #    a bare KeyError only after the model had already been run.
        if captcha_type not in self._chars or captcha_type not in self._sessions:
            raise RuntimeError(
                f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
            )

        inp = self.preprocess(img, IMAGE_SIZE[captcha_type])
        session = self._sessions[captcha_type]
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inp})[0]  # (T, 1, C)

        # 4. Greedy CTC decoding.
        raw_text = self._ctc_greedy_decode(logits, self._chars[captcha_type])

        # 5. Post-processing: math captchas are evaluated to their answer.
        result = raw_text
        if captcha_type == "math":
            try:
                result = eval_captcha_math(raw_text)
            except ValueError:
                # Deliberate best effort: keep the raw text when the
                # recognized expression cannot be parsed.
                result = raw_text

        elapsed = (time.perf_counter() - t0) * 1000
        return {
            "type": captcha_type,
            "raw": raw_text,
            "result": result,
            "time_ms": round(elapsed, 2),
        }

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    @staticmethod
    def _load_image(image) -> Image.Image:
        """
        Normalize the supported input types to a PIL image.

        A PIL image is returned unchanged; a path or bytes object is opened
        and converted to RGB.

        Raises:
            TypeError: If ``image`` is none of PIL.Image / str / Path / bytes.
        """
        if isinstance(image, Image.Image):
            return image

        if isinstance(image, (str, Path)):
            source = image
        elif isinstance(image, bytes):
            source = io.BytesIO(image)
        else:
            raise TypeError(f"不支持的图片输入类型: {type(image)}")

        # convert() returns an independent, fully loaded image, so the source
        # can be closed right away instead of leaking an open file handle.
        with Image.open(source) as opened:
            return opened.convert("RGB")

    @staticmethod
    def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
        """
        Greedy CTC decoding (NumPy version).

        Args:
            logits: Model output indexed as (T, B, C); only batch item 0 is
                decoded.
            chars: Character set without the blank symbol; class index 0 is
                the blank, so class ``i`` maps to ``chars[i - 1]``.

        Returns:
            The decoded string.
        """
        preds = np.argmax(logits[:, 0, :], axis=1)  # (T,)

        # A step starts a new symbol when it differs from the previous step;
        # the first step always does.
        starts_run = np.ones(preds.shape[0], dtype=bool)
        starts_run[1:] = preds[1:] != preds[:-1]

        # Collapse repeats, then drop blanks (index 0).
        kept = preds[starts_run & (preds != 0)]
        return "".join(chars[idx - 1] for idx in kept)