Initialize repository
This commit is contained in:
231
inference/pipeline.py
Normal file
231
inference/pipeline.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
核心推理流水线
|
||||
|
||||
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
|
||||
|
||||
推理流程:
|
||||
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
|
||||
|
||||
对算式类型,解码后还会调用 math_eval 计算结果。
|
||||
"""
|
||||
|
||||
import io
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import (
|
||||
CAPTCHA_TYPES,
|
||||
IMAGE_SIZE,
|
||||
INFERENCE_CONFIG,
|
||||
NORMAL_CHARS,
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
)
|
||||
from inference.math_eval import eval_captcha_math
|
||||
|
||||
|
||||
def _try_import_ort():
|
||||
"""延迟导入 onnxruntime,给出友好错误提示。"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
return ort
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"推理需要 onnxruntime,请安装: uv pip install onnxruntime"
|
||||
)
|
||||
|
||||
|
||||
class CaptchaPipeline:
|
||||
"""
|
||||
核心推理流水线。
|
||||
|
||||
加载调度模型和所有专家模型 (ONNX 格式)。
|
||||
提供统一的 solve(image) 接口。
|
||||
"""
|
||||
|
||||
def __init__(self, models_dir: str | None = None):
|
||||
"""
|
||||
初始化加载所有 ONNX 模型。
|
||||
|
||||
Args:
|
||||
models_dir: ONNX 模型目录,默认使用 config 中的路径
|
||||
"""
|
||||
ort = _try_import_ort()
|
||||
|
||||
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
|
||||
self.mean = INFERENCE_CONFIG["normalize_mean"]
|
||||
self.std = INFERENCE_CONFIG["normalize_std"]
|
||||
|
||||
# 字符集映射
|
||||
self._chars = {
|
||||
"normal": NORMAL_CHARS,
|
||||
"math": MATH_CHARS,
|
||||
"3d": THREED_CHARS,
|
||||
}
|
||||
|
||||
# 专家模型名 → ONNX 文件名
|
||||
self._model_files = {
|
||||
"classifier": "classifier.onnx",
|
||||
"normal": "normal.onnx",
|
||||
"math": "math.onnx",
|
||||
"3d": "threed.onnx",
|
||||
}
|
||||
|
||||
# 加载所有可用模型
|
||||
opts = ort.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 2
|
||||
|
||||
self._sessions: dict[str, "ort.InferenceSession"] = {}
|
||||
for name, fname in self._model_files.items():
|
||||
path = self.models_dir / fname
|
||||
if path.exists():
|
||||
self._sessions[name] = ort.InferenceSession(
|
||||
str(path), sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
loaded = list(self._sessions.keys())
|
||||
if not loaded:
|
||||
raise FileNotFoundError(
|
||||
f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# 公共接口
|
||||
# ----------------------------------------------------------
|
||||
def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
|
||||
"""
|
||||
图片预处理: resize, grayscale, normalize, 转 numpy。
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
target_size: (H, W)
|
||||
|
||||
Returns:
|
||||
(1, 1, H, W) float32 ndarray
|
||||
"""
|
||||
h, w = target_size
|
||||
img = image.convert("L").resize((w, h), Image.BILINEAR)
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
arr = (arr - self.mean) / self.std
|
||||
return arr.reshape(1, 1, h, w)
|
||||
|
||||
def classify(self, image: Image.Image) -> str:
|
||||
"""
|
||||
调度分类,返回类型名: 'normal' / 'math' / '3d'。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 分类器模型未加载
|
||||
"""
|
||||
if "classifier" not in self._sessions:
|
||||
raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")
|
||||
|
||||
inp = self.preprocess(image, IMAGE_SIZE["classifier"])
|
||||
session = self._sessions["classifier"]
|
||||
input_name = session.get_inputs()[0].name
|
||||
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
|
||||
idx = int(np.argmax(logits, axis=1)[0])
|
||||
return CAPTCHA_TYPES[idx]
|
||||
|
||||
def solve(
|
||||
self,
|
||||
image,
|
||||
captcha_type: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
完整识别流程。
|
||||
|
||||
Args:
|
||||
image: PIL.Image 或文件路径 (str/Path) 或 bytes
|
||||
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"type": str, # 验证码类型
|
||||
"raw": str, # OCR 原始识别结果
|
||||
"result": str, # 最终答案 (算式型为计算结果)
|
||||
"time_ms": float, # 推理耗时 (毫秒)
|
||||
}
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# 1. 解析输入
|
||||
img = self._load_image(image)
|
||||
|
||||
# 2. 分类
|
||||
if captcha_type is None:
|
||||
captcha_type = self.classify(img)
|
||||
|
||||
# 3. 路由到专家模型
|
||||
if captcha_type not in self._sessions:
|
||||
raise RuntimeError(
|
||||
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
|
||||
)
|
||||
|
||||
size_key = captcha_type # "normal"/"math"/"3d"
|
||||
inp = self.preprocess(img, IMAGE_SIZE[size_key])
|
||||
session = self._sessions[captcha_type]
|
||||
input_name = session.get_inputs()[0].name
|
||||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||||
|
||||
# 4. CTC 贪心解码
|
||||
chars = self._chars[captcha_type]
|
||||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||||
|
||||
# 5. 后处理
|
||||
if captcha_type == "math":
|
||||
try:
|
||||
result = eval_captcha_math(raw_text)
|
||||
except ValueError:
|
||||
result = raw_text # 解析失败则返回原始文本
|
||||
else:
|
||||
result = raw_text
|
||||
|
||||
elapsed = (time.perf_counter() - t0) * 1000
|
||||
|
||||
return {
|
||||
"type": captcha_type,
|
||||
"raw": raw_text,
|
||||
"result": result,
|
||||
"time_ms": round(elapsed, 2),
|
||||
}
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# 私有方法
|
||||
# ----------------------------------------------------------
|
||||
@staticmethod
|
||||
def _load_image(image) -> Image.Image:
|
||||
"""将多种输入类型统一转为 PIL Image。"""
|
||||
if isinstance(image, Image.Image):
|
||||
return image
|
||||
if isinstance(image, (str, Path)):
|
||||
return Image.open(image).convert("RGB")
|
||||
if isinstance(image, bytes):
|
||||
return Image.open(io.BytesIO(image)).convert("RGB")
|
||||
raise TypeError(f"不支持的图片输入类型: {type(image)}")
|
||||
|
||||
@staticmethod
|
||||
def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
|
||||
"""
|
||||
CTC 贪心解码 (numpy 版本)。
|
||||
|
||||
Args:
|
||||
logits: (T, B, C) ONNX 输出
|
||||
chars: 字符集 (不含 blank, blank=index 0)
|
||||
|
||||
Returns:
|
||||
解码后的字符串
|
||||
"""
|
||||
# 取 batch=0
|
||||
preds = np.argmax(logits[:, 0, :], axis=1) # (T,)
|
||||
decoded = []
|
||||
prev = -1
|
||||
for idx in preds:
|
||||
if idx != 0 and idx != prev:
|
||||
decoded.append(chars[idx - 1])
|
||||
prev = idx
|
||||
return "".join(decoded)
|
||||
Reference in New Issue
Block a user