Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -26,6 +26,7 @@ from config import (
REGRESSION_RANGE,
)
from inference.math_eval import eval_captcha_math
from inference.model_metadata import load_model_metadata
def _try_import_ort():
@@ -66,6 +67,11 @@ class CaptchaPipeline:
"math": MATH_CHARS,
"3d_text": THREED_CHARS,
}
self._classifier_class_names = tuple(CAPTCHA_TYPES)
self._regression_ranges = {
"3d_rotate": REGRESSION_RANGE["3d_rotate"],
"3d_slider": REGRESSION_RANGE["3d_slider"],
}
# 回归模型类型
self._regression_types = {"3d_rotate", "3d_slider"}
@@ -86,6 +92,7 @@ class CaptchaPipeline:
opts.intra_op_num_threads = 2
self._sessions: dict[str, "ort.InferenceSession"] = {}
self._metadata: dict[str, dict] = {}
for name, fname in self._model_files.items():
path = self.models_dir / fname
if path.exists():
@@ -93,6 +100,7 @@ class CaptchaPipeline:
str(path), sess_options=opts,
providers=["CPUExecutionProvider"],
)
self._metadata[name] = load_model_metadata(path) or {}
loaded = list(self._sessions.keys())
if not loaded:
@@ -135,7 +143,14 @@ class CaptchaPipeline:
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
idx = int(np.argmax(logits, axis=1)[0])
return CAPTCHA_TYPES[idx]
class_names = tuple(
self._metadata.get("classifier", {}).get("class_names", self._classifier_class_names)
)
if idx >= len(class_names):
raise RuntimeError(
f"分类器输出索引越界: idx={idx}, classes={len(class_names)}"
)
return class_names[idx]
def solve(
self,
@@ -182,14 +197,17 @@ class CaptchaPipeline:
# 回归模型: 输出 (batch, 1) sigmoid 值
output = session.run(None, {input_name: inp})[0] # (1, 1)
sigmoid_val = float(output[0, 0])
lo, hi = REGRESSION_RANGE[captcha_type]
lo, hi = self._metadata.get(captcha_type, {}).get(
"label_range",
self._regression_ranges[captcha_type],
)
real_val = sigmoid_val * (hi - lo) + lo
raw_text = f"{real_val:.1f}"
result = str(int(round(real_val)))
else:
# CTC 模型
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
chars = self._chars[captcha_type]
chars = self._metadata.get(captcha_type, {}).get("chars", self._chars[captcha_type])
raw_text = self._ctc_greedy_decode(logits, chars)
# 后处理