Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,10 +17,12 @@ from config import (
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
NUM_CAPTCHA_TYPES,
|
||||
REGRESSION_RANGE,
|
||||
)
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
|
||||
|
||||
def export_model(
|
||||
@@ -34,7 +36,7 @@ def export_model(
|
||||
|
||||
Args:
|
||||
model: 已加载权重的 PyTorch 模型
|
||||
model_name: 模型名 (classifier / normal / math / threed)
|
||||
model_name: 模型名 (classifier / normal / math / threed_text / threed_rotate / threed_slider)
|
||||
input_shape: 输入形状 (C, H, W)
|
||||
onnx_dir: 输出目录 (默认使用 config.ONNX_DIR)
|
||||
"""
|
||||
@@ -50,7 +52,7 @@ def export_model(
|
||||
dummy = torch.randn(1, *input_shape)
|
||||
|
||||
# 分类器和识别器的 dynamic_axes 不同
|
||||
if model_name == "classifier":
|
||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
else:
|
||||
# CTC 模型: output shape = (T, B, C)
|
||||
@@ -78,7 +80,8 @@ def _load_and_export(model_name: str):
|
||||
return
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")
|
||||
acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
|
||||
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")
|
||||
|
||||
if model_name == "classifier":
|
||||
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
|
||||
@@ -94,11 +97,19 @@ def _load_and_export(model_name: str):
|
||||
h, w = IMAGE_SIZE["math"]
|
||||
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed":
|
||||
elif model_name == "threed_text":
|
||||
chars = ckpt.get("chars", THREED_CHARS)
|
||||
h, w = IMAGE_SIZE["3d"]
|
||||
h, w = IMAGE_SIZE["3d_text"]
|
||||
model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed_rotate":
|
||||
h, w = IMAGE_SIZE["3d_rotate"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "threed_slider":
|
||||
h, w = IMAGE_SIZE["3d_slider"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
else:
|
||||
print(f"[错误] 未知模型: {model_name}")
|
||||
return
|
||||
@@ -108,11 +119,11 @@ def _load_and_export(model_name: str):
|
||||
|
||||
|
||||
def export_all():
|
||||
"""依次导出 classifier, normal, math, threed 四个模型。"""
|
||||
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
|
||||
print("=" * 50)
|
||||
print("导出全部 ONNX 模型")
|
||||
print("=" * 50)
|
||||
for name in ["classifier", "normal", "math", "threed"]:
|
||||
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
|
||||
_load_and_export(name)
|
||||
print("\n全部导出完成。")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user