Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models:

- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,10 +17,12 @@ from config import (
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
NUM_CAPTCHA_TYPES,
|
||||
REGRESSION_RANGE,
|
||||
)
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
|
||||
|
||||
def export_model(
|
||||
@@ -34,7 +36,7 @@ def export_model(
|
||||
|
||||
Args:
|
||||
model: 已加载权重的 PyTorch 模型
|
||||
model_name: 模型名 (classifier / normal / math / threed)
|
||||
model_name: 模型名 (classifier / normal / math / threed_text / threed_rotate / threed_slider)
|
||||
input_shape: 输入形状 (C, H, W)
|
||||
onnx_dir: 输出目录 (默认使用 config.ONNX_DIR)
|
||||
"""
|
||||
@@ -50,7 +52,7 @@ def export_model(
|
||||
dummy = torch.randn(1, *input_shape)
|
||||
|
||||
# 分类器和识别器的 dynamic_axes 不同
|
||||
if model_name == "classifier":
|
||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
else:
|
||||
# CTC 模型: output shape = (T, B, C)
|
||||
@@ -78,7 +80,8 @@ def _load_and_export(model_name: str):
|
||||
return
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")
|
||||
acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
|
||||
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")
|
||||
|
||||
if model_name == "classifier":
|
||||
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
|
||||
@@ -94,11 +97,19 @@ def _load_and_export(model_name: str):
|
||||
h, w = IMAGE_SIZE["math"]
|
||||
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed":
|
||||
elif model_name == "threed_text":
|
||||
chars = ckpt.get("chars", THREED_CHARS)
|
||||
h, w = IMAGE_SIZE["3d"]
|
||||
h, w = IMAGE_SIZE["3d_text"]
|
||||
model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed_rotate":
|
||||
h, w = IMAGE_SIZE["3d_rotate"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed_slider":
|
||||
h, w = IMAGE_SIZE["3d_slider"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
else:
|
||||
print(f"[错误] 未知模型: {model_name}")
|
||||
return
|
||||
@@ -108,11 +119,11 @@ def _load_and_export(model_name: str):
|
||||
|
||||
|
||||
def export_all():
|
||||
"""依次导出 classifier, normal, math, threed 四个模型。"""
|
||||
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
|
||||
print("=" * 50)
|
||||
print("导出全部 ONNX 模型")
|
||||
print("=" * 50)
|
||||
for name in ["classifier", "normal", "math", "threed"]:
|
||||
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
|
||||
_load_and_export(name)
|
||||
print("\n全部导出完成。")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
|
||||
|
||||
推理流程:
|
||||
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
|
||||
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 / 回归缩放 → 后处理 → 输出
|
||||
|
||||
对算式类型,解码后还会调用 math_eval 计算结果。
|
||||
"""
|
||||
@@ -23,6 +23,7 @@ from config import (
|
||||
NORMAL_CHARS,
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
REGRESSION_RANGE,
|
||||
)
|
||||
from inference.math_eval import eval_captcha_math
|
||||
|
||||
@@ -59,19 +60,24 @@ class CaptchaPipeline:
|
||||
self.mean = INFERENCE_CONFIG["normalize_mean"]
|
||||
self.std = INFERENCE_CONFIG["normalize_std"]
|
||||
|
||||
# 字符集映射
|
||||
# 字符集映射 (仅 CTC 模型需要)
|
||||
self._chars = {
|
||||
"normal": NORMAL_CHARS,
|
||||
"math": MATH_CHARS,
|
||||
"3d": THREED_CHARS,
|
||||
"3d_text": THREED_CHARS,
|
||||
}
|
||||
|
||||
# 回归模型类型
|
||||
self._regression_types = {"3d_rotate", "3d_slider"}
|
||||
|
||||
# 专家模型名 → ONNX 文件名
|
||||
self._model_files = {
|
||||
"classifier": "classifier.onnx",
|
||||
"normal": "normal.onnx",
|
||||
"math": "math.onnx",
|
||||
"3d": "threed.onnx",
|
||||
"3d_text": "threed_text.onnx",
|
||||
"3d_rotate": "threed_rotate.onnx",
|
||||
"3d_slider": "threed_slider.onnx",
|
||||
}
|
||||
|
||||
# 加载所有可用模型
|
||||
@@ -116,7 +122,7 @@ class CaptchaPipeline:
|
||||
|
||||
def classify(self, image: Image.Image) -> str:
|
||||
"""
|
||||
调度分类,返回类型名: 'normal' / 'math' / '3d'。
|
||||
调度分类,返回类型名。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 分类器模型未加载
|
||||
@@ -141,7 +147,7 @@ class CaptchaPipeline:
|
||||
|
||||
Args:
|
||||
image: PIL.Image 或文件路径 (str/Path) 或 bytes
|
||||
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
|
||||
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d_text'/'3d_rotate'/'3d_slider')
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -166,24 +172,34 @@ class CaptchaPipeline:
|
||||
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
|
||||
)
|
||||
|
||||
size_key = captcha_type # "normal"/"math"/"3d"
|
||||
size_key = captcha_type
|
||||
inp = self.preprocess(img, IMAGE_SIZE[size_key])
|
||||
session = self._sessions[captcha_type]
|
||||
input_name = session.get_inputs()[0].name
|
||||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||||
|
||||
# 4. CTC 贪心解码
|
||||
chars = self._chars[captcha_type]
|
||||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||||
|
||||
# 5. 后处理
|
||||
if captcha_type == "math":
|
||||
try:
|
||||
result = eval_captcha_math(raw_text)
|
||||
except ValueError:
|
||||
result = raw_text # 解析失败则返回原始文本
|
||||
# 4. 分支: CTC 解码 vs 回归
|
||||
if captcha_type in self._regression_types:
|
||||
# 回归模型: 输出 (batch, 1) sigmoid 值
|
||||
output = session.run(None, {input_name: inp})[0] # (1, 1)
|
||||
sigmoid_val = float(output[0, 0])
|
||||
lo, hi = REGRESSION_RANGE[captcha_type]
|
||||
real_val = sigmoid_val * (hi - lo) + lo
|
||||
raw_text = f"{real_val:.1f}"
|
||||
result = str(int(round(real_val)))
|
||||
else:
|
||||
result = raw_text
|
||||
# CTC 模型
|
||||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||||
chars = self._chars[captcha_type]
|
||||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||||
|
||||
# 后处理
|
||||
if captcha_type == "math":
|
||||
try:
|
||||
result = eval_captcha_math(raw_text)
|
||||
except ValueError:
|
||||
result = raw_text
|
||||
else:
|
||||
result = raw_text
|
||||
|
||||
elapsed = (time.perf_counter() - t0) * 1000
|
||||
|
||||
|
||||
Reference in New Issue
Block a user