Support external FunCaptcha ONNX fallback
This commit is contained in:
@@ -5,6 +5,7 @@ FunCaptcha 专项 ONNX 推理。
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -32,9 +33,7 @@ class FunCaptchaRollballPipeline:
|
||||
self.question = question
|
||||
self.task_cfg = FUN_CAPTCHA_TASKS[question]
|
||||
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
|
||||
self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
|
||||
if not self.model_path.exists():
|
||||
raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")
|
||||
self.model_path = self._resolve_model_path()
|
||||
|
||||
opts = ort.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
@@ -45,6 +44,7 @@ class FunCaptchaRollballPipeline:
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.metadata = load_model_metadata(self.model_path) or {}
|
||||
self.preprocess_mode = self._resolve_preprocess_mode(self.metadata)
|
||||
self.mean = float(INFERENCE_CONFIG["normalize_mean"])
|
||||
self.std = float(INFERENCE_CONFIG["normalize_std"])
|
||||
self.answer_index_base = int(
|
||||
@@ -57,14 +57,12 @@ class FunCaptchaRollballPipeline:
|
||||
candidates, reference = self._split_challenge(challenge)
|
||||
|
||||
ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)
|
||||
input_names = [inp.name for inp in self.session.get_inputs()]
|
||||
input_defs = self.session.get_inputs()
|
||||
input_names = [inp.name for inp in input_defs]
|
||||
if len(input_names) != 2:
|
||||
raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")
|
||||
|
||||
logits = self.session.run(None, {
|
||||
input_names[0]: candidates,
|
||||
input_names[1]: ref_batch,
|
||||
})[0].reshape(-1)
|
||||
logits = self._run_model(input_defs, input_names, candidates, ref_batch)
|
||||
scores = 1.0 / (1.0 + np.exp(-logits))
|
||||
answer_idx = int(np.argmax(logits))
|
||||
selected = answer_idx + self.answer_index_base
|
||||
@@ -100,9 +98,18 @@ class FunCaptchaRollballPipeline:
|
||||
|
||||
def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
|
||||
img_h, img_w = target_size
|
||||
image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
|
||||
image = image.convert("RGB")
|
||||
if self.preprocess_mode == "rgb_centered":
|
||||
image = image.resize((img_w, img_h), Image.BILINEAR)
|
||||
elif self.preprocess_mode == "rgb_255":
|
||||
# 对齐 funcaptcha-server 现有 ONNX 的预处理行为。
|
||||
image = image.resize((img_w, img_h))
|
||||
else:
|
||||
raise ValueError(f"不支持的 FunCaptcha 预处理模式: {self.preprocess_mode}")
|
||||
|
||||
arr = np.asarray(image, dtype=np.float32) / 255.0
|
||||
arr = (arr - self.mean) / self.std
|
||||
if self.preprocess_mode == "rgb_centered":
|
||||
arr = (arr - self.mean) / self.std
|
||||
arr = np.transpose(arr, (2, 0, 1))
|
||||
return arr.reshape(1, 3, img_h, img_w)
|
||||
|
||||
@@ -115,3 +122,69 @@ class FunCaptchaRollballPipeline:
|
||||
if isinstance(image, bytes):
|
||||
return Image.open(io.BytesIO(image)).convert("RGB")
|
||||
raise TypeError(f"不支持的图片输入类型: {type(image)}")
|
||||
|
||||
def _resolve_model_path(self) -> Path:
|
||||
candidates = [self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"]
|
||||
|
||||
env_name = self.task_cfg.get("external_model_env")
|
||||
env_value = os.getenv(env_name) if env_name else None
|
||||
if env_value:
|
||||
candidates.append(Path(env_value).expanduser())
|
||||
|
||||
for fallback in self.task_cfg.get("fallback_model_paths", []):
|
||||
candidates.append(Path(fallback).expanduser())
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
|
||||
tried = ", ".join(str(path) for path in candidates)
|
||||
raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型,已尝试: {tried}")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_preprocess_mode(metadata: dict) -> str:
|
||||
preprocess = metadata.get("preprocess")
|
||||
if preprocess:
|
||||
return str(preprocess)
|
||||
if metadata.get("task") == "funcaptcha_siamese":
|
||||
return "rgb_centered"
|
||||
return "rgb_255"
|
||||
|
||||
def _run_model(
|
||||
self,
|
||||
input_defs,
|
||||
input_names: list[str],
|
||||
candidates: np.ndarray,
|
||||
reference_batch: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
batch_size = candidates.shape[0]
|
||||
batch_axis = None
|
||||
for input_def in input_defs:
|
||||
shape = getattr(input_def, "shape", None)
|
||||
if isinstance(shape, (list, tuple)) and shape:
|
||||
batch_axis = shape[0]
|
||||
break
|
||||
|
||||
if batch_axis in (None, "batch", "None", -1) or not isinstance(batch_axis, int):
|
||||
return self.session.run(None, {
|
||||
input_names[0]: candidates,
|
||||
input_names[1]: reference_batch,
|
||||
})[0].reshape(-1)
|
||||
|
||||
if batch_axis == batch_size:
|
||||
return self.session.run(None, {
|
||||
input_names[0]: candidates,
|
||||
input_names[1]: reference_batch,
|
||||
})[0].reshape(-1)
|
||||
|
||||
if batch_axis != 1:
|
||||
raise RuntimeError(f"专项模型不支持当前 batch 维度: expected={batch_axis} actual={batch_size}")
|
||||
|
||||
outputs = []
|
||||
for idx in range(batch_size):
|
||||
out = self.session.run(None, {
|
||||
input_names[0]: candidates[idx:idx + 1],
|
||||
input_names[1]: reference_batch[idx:idx + 1],
|
||||
})[0]
|
||||
outputs.append(np.asarray(out, dtype=np.float32).reshape(-1))
|
||||
return np.concatenate(outputs, axis=0)
|
||||
|
||||
Reference in New Issue
Block a user