Align task API and add FunCaptcha support
This commit is contained in:
@@ -26,6 +26,7 @@ from config import (
|
||||
REGRESSION_RANGE,
|
||||
)
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.model_metadata import load_model_metadata
|
||||
|
||||
|
||||
def _try_import_ort():
|
||||
@@ -66,6 +67,11 @@ class CaptchaPipeline:
|
||||
"math": MATH_CHARS,
|
||||
"3d_text": THREED_CHARS,
|
||||
}
|
||||
self._classifier_class_names = tuple(CAPTCHA_TYPES)
|
||||
self._regression_ranges = {
|
||||
"3d_rotate": REGRESSION_RANGE["3d_rotate"],
|
||||
"3d_slider": REGRESSION_RANGE["3d_slider"],
|
||||
}
|
||||
|
||||
# Regression model types
|
||||
self._regression_types = {"3d_rotate", "3d_slider"}
|
||||
@@ -86,6 +92,7 @@ class CaptchaPipeline:
|
||||
opts.intra_op_num_threads = 2
|
||||
|
||||
self._sessions: dict[str, "ort.InferenceSession"] = {}
|
||||
self._metadata: dict[str, dict] = {}
|
||||
for name, fname in self._model_files.items():
|
||||
path = self.models_dir / fname
|
||||
if path.exists():
|
||||
@@ -93,6 +100,7 @@ class CaptchaPipeline:
|
||||
str(path), sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self._metadata[name] = load_model_metadata(path) or {}
|
||||
|
||||
loaded = list(self._sessions.keys())
|
||||
if not loaded:
|
||||
@@ -135,7 +143,14 @@ class CaptchaPipeline:
|
||||
input_name = session.get_inputs()[0].name
|
||||
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
|
||||
idx = int(np.argmax(logits, axis=1)[0])
|
||||
return CAPTCHA_TYPES[idx]
|
||||
class_names = tuple(
|
||||
self._metadata.get("classifier", {}).get("class_names", self._classifier_class_names)
|
||||
)
|
||||
if idx >= len(class_names):
|
||||
raise RuntimeError(
|
||||
f"分类器输出索引越界: idx={idx}, classes={len(class_names)}"
|
||||
)
|
||||
return class_names[idx]
|
||||
|
||||
def solve(
|
||||
self,
|
||||
@@ -182,14 +197,17 @@ class CaptchaPipeline:
|
||||
# Regression model: outputs a (batch, 1) sigmoid value
|
||||
output = session.run(None, {input_name: inp})[0] # (1, 1)
|
||||
sigmoid_val = float(output[0, 0])
|
||||
lo, hi = REGRESSION_RANGE[captcha_type]
|
||||
lo, hi = self._metadata.get(captcha_type, {}).get(
|
||||
"label_range",
|
||||
self._regression_ranges[captcha_type],
|
||||
)
|
||||
real_val = sigmoid_val * (hi - lo) + lo
|
||||
raw_text = f"{real_val:.1f}"
|
||||
result = str(int(round(real_val)))
|
||||
else:
|
||||
# CTC model
|
||||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||||
chars = self._chars[captcha_type]
|
||||
chars = self._metadata.get(captcha_type, {}).get("chars", self._chars[captcha_type])
|
||||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||||
|
||||
# Post-processing
|
||||
|
||||
Reference in New Issue
Block a user