Support external FunCaptcha ONNX fallback
@@ -21,14 +21,14 @@ Use `uv` for environment and dependency management.
- `uv run captcha export --model 3d_text` maps to `threed_text`. The export loader also accepts internal artifact names such as `threed_rotate`, `gap_detector`, `rotation_regressor`, and `funcaptcha_rollball_animals`; `4_3d_rollball_animals` is accepted as an alias for that FunCaptcha artifact.
- `uv run captcha predict image.png` runs auto-routing inference. Add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference for `.png` and `.jpg` files.
- `uv run captcha predict-funcaptcha image.jpg --question 4_3d_rollball_animals` runs the dedicated FunCaptcha matcher and returns `objects`. It resolves the ONNX in this order: `onnx_models/funcaptcha_rollball_animals.onnx` -> env `FUNCAPTCHA_ROLLBALL_MODEL_PATH` -> configured fallback path such as the sibling `funcaptcha-server/model/4_3d_rollball_animals.onnx`.
- `uv run captcha solve slide --bg bg.png [--tpl tpl.png]` runs the slide solver. It uses template matching first when `--tpl` is provided, then OpenCV edge detection, then CNN fallback.
- `uv run captcha solve rotate --image img.png` runs the rotate solver.
- `uv run captcha serve --host 0.0.0.0 --port 8080` starts the implemented FastAPI service in `server.py`. It supports synchronous `/solve` and `/solve/upload`, plus async task endpoints `/createTask`, `/getTaskResult`, and `/getBalance`, with `/api/v1/*` compatibility aliases. If `CLIENT_KEY` is set in the environment, task endpoints require a matching `clientKey`. `createTask` accepts `callbackUrl`, `softId`, `languagePool`, and optional `task.question`; `task.question=4_3d_rollball_animals` routes to the dedicated FunCaptcha matcher and returns `solution.objects`. `callbackUrl` receives a form-encoded completion callback with configurable retry/backoff in `SERVER_CONFIG`. If `CALLBACK_SIGNING_SECRET` is set, callback requests include HMAC-SHA256 signature headers. Task responses also expose extra `task` / `callback` metadata for async debugging, and task state is persisted under `data/server_tasks/`.
- `uv run pytest` runs the test suite.
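Verifying those HMAC-SHA256 callback signatures on the receiving side can be sketched as follows. This is a minimal illustration, not the repo's actual implementation; the header layout and hex encoding are assumptions:

```python
import hashlib
import hmac

def verify_callback(secret: str, body: bytes, signature_hex: str) -> bool:
    """Recompute HMAC-SHA256 over the raw form-encoded body and compare.

    `signature_hex` stands for whatever signature header the server sends;
    the exact header name is not specified in this document.
    """
    expected = hmac.new(secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
    # compare_digest avoids timing side channels on the comparison.
    return hmac.compare_digest(expected, signature_hex)
```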

## Coding Style & Naming Conventions

Target Python 3.10-3.12 and follow the existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep public captcha type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Internal checkpoint/ONNX artifact names use `threed_text`, `threed_rotate`, `threed_slider`, and `funcaptcha_rollball_animals`; solver artifacts are `gap_detector` and `rotation_regressor`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ONNX ops, and greedy CTC decoding for OCR models. `normal` uses `NORMAL_CHARS`, `math` uses `MATH_CHARS` and must be post-processed through `inference/math_eval.py`, and `3d_text` uses `THREED_CHARS`. `3d_rotate` and `3d_slider` output sigmoid values in `[0, 1]` and scale them with `REGRESSION_RANGE`; the rotate solver model outputs `(sin, cos)` on RGB input. The FunCaptcha matcher is a dual-input RGB Siamese model keyed by `task.question`, not by `captchaType`. Runtime ONNX artifacts belong under `onnx_models/`, not `models/`; external FunCaptcha ONNX files may omit metadata, in which case inference must preserve the external preprocessing contract instead of assuming the repo's centered RGB normalization.
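The regression contract above can be illustrated with a short sketch. The range constant and degree units are assumptions standing in for the repo's actual `REGRESSION_RANGE`:

```python
import math

REGRESSION_RANGE = 360.0  # assumption: degrees; the repo defines the real constant

def scale_regression(sigmoid_value: float) -> float:
    """Map a sigmoid output in [0, 1] onto the regression range."""
    return sigmoid_value * REGRESSION_RANGE

def angle_from_sin_cos(sin_v: float, cos_v: float) -> float:
    """Recover an angle in [0, 360) degrees from a (sin, cos) prediction."""
    return math.degrees(math.atan2(sin_v, cos_v)) % 360.0
```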

- Do not casually upgrade `torch` or `torchvision`: newer CUDA 12.8 wheels in this repo's previous environment dropped `sm_61` kernels and failed on GTX 1050 Ti. Re-verify GPU execution before changing the pinned pair.

## Training & Data Rules
12 CLAUDE.md
@@ -484,11 +484,23 @@ uv run python cli.py predict image.png --type normal # skip classification and run directly
uv run python cli.py predict image.png --type 3d_rotate  # force the rotate type
uv run python cli.py predict-dir ./test_images/          # batch recognition
uv run python cli.py predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
FUNCAPTCHA_ROLLBALL_MODEL_PATH=/path/to/4_3d_rollball_animals.onnx \
uv run python cli.py predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals

# Start the HTTP service (install the optional server dependencies first)
uv run python cli.py serve --port 8080
```

The dedicated FunCaptcha ONNX lookup order:

- `onnx_models/funcaptcha_rollball_animals.onnx`
- the `FUNCAPTCHA_ROLLBALL_MODEL_PATH` environment variable
- the default fallback `/mnt/data/code/python/funcaptcha-server/model/4_3d_rollball_animals.onnx`

Notes:

- `models/` holds Python source code; do not place ONNX artifacts there
- FunCaptcha ONNX files exported by this repo carry sidecar metadata and use centered RGB preprocessing
- External compatible ONNX files may lack metadata; the inference layer must then fall back to funcaptcha-server's `/255.0` RGB preprocessing contract
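A minimal sketch of the two preprocessing contracts, assuming placeholder mean/std values standing in for the repo's `INFERENCE_CONFIG`:

```python
import numpy as np

MEAN, STD = 0.5, 0.5  # assumption: stand-ins for the repo's normalization constants

def preprocess(rgb: np.ndarray, mode: str) -> np.ndarray:
    """Convert an HxWx3 uint8 image into a 1x3xHxW float32 tensor."""
    arr = rgb.astype(np.float32) / 255.0  # both contracts first scale to [0, 1]
    if mode == "rgb_centered":            # this repo's exported models
        arr = (arr - MEAN) / STD
    elif mode != "rgb_255":               # external funcaptcha-server models skip centering
        raise ValueError(f"unknown preprocess mode: {mode}")
    return np.transpose(arr, (2, 0, 1))[None, ...]
```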

## HTTP Service (server.py, optional)

A pure inference service: it does not depend on torch or the training code and only needs onnxruntime + FastAPI.
14 README.md
@@ -143,6 +143,20 @@ uv run captcha export --model 4_3d_rollball_animals
uv run captcha predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
```

If you have no training data yet, you can also reuse an external ONNX directly:

```bash
FUNCAPTCHA_ROLLBALL_MODEL_PATH=/path/to/4_3d_rollball_animals.onnx \
uv run captcha predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
```

The inference lookup order is:

- `onnx_models/funcaptcha_rollball_animals.onnx`
- the `FUNCAPTCHA_ROLLBALL_MODEL_PATH` environment variable
- the default fallback `/mnt/data/code/python/funcaptcha-server/model/4_3d_rollball_animals.onnx`

Do not put ONNX files in `models/`; that directory holds Python model definition source code, and runtime model artifacts belong in `onnx_models/`.

## HTTP API

```bash
@@ -254,6 +254,10 @@ FUN_CAPTCHA_TASKS = {
        "num_candidates": 4,
        "answer_index_base": 0,
        "channels": 3,
+       "external_model_env": "FUNCAPTCHA_ROLLBALL_MODEL_PATH",
+       "fallback_model_paths": [
+           str(PROJECT_ROOT.parent / "funcaptcha-server" / "model" / "4_3d_rollball_animals.onnx"),
+       ],
    },
}
@@ -205,6 +205,7 @@ def _load_and_export(model_name: str):
    metadata = {
        "model_name": model_name,
        "task": "funcaptcha_siamese",
+       "preprocess": "rgb_centered",
        "question": question,
        "num_candidates": int(ckpt.get("num_candidates", task_cfg["num_candidates"])),
        "tile_size": list(ckpt.get("tile_size", task_cfg["tile_size"])),
@@ -5,6 +5,7 @@ Dedicated FunCaptcha ONNX inference.
from __future__ import annotations

import io
+import os
import time
from pathlib import Path
@@ -32,9 +33,7 @@ class FunCaptchaRollballPipeline:
        self.question = question
        self.task_cfg = FUN_CAPTCHA_TASKS[question]
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
-        self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
-        if not self.model_path.exists():
-            raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")
+        self.model_path = self._resolve_model_path()

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
@@ -45,6 +44,7 @@ class FunCaptchaRollballPipeline:
            providers=["CPUExecutionProvider"],
        )
        self.metadata = load_model_metadata(self.model_path) or {}
+       self.preprocess_mode = self._resolve_preprocess_mode(self.metadata)
        self.mean = float(INFERENCE_CONFIG["normalize_mean"])
        self.std = float(INFERENCE_CONFIG["normalize_std"])
        self.answer_index_base = int(
@@ -57,14 +57,12 @@ class FunCaptchaRollballPipeline:
        candidates, reference = self._split_challenge(challenge)

        ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)
-        input_names = [inp.name for inp in self.session.get_inputs()]
+        input_defs = self.session.get_inputs()
+        input_names = [inp.name for inp in input_defs]
        if len(input_names) != 2:
            raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

-        logits = self.session.run(None, {
-            input_names[0]: candidates,
-            input_names[1]: ref_batch,
-        })[0].reshape(-1)
+        logits = self._run_model(input_defs, input_names, candidates, ref_batch)
        scores = 1.0 / (1.0 + np.exp(-logits))
        answer_idx = int(np.argmax(logits))
        selected = answer_idx + self.answer_index_base
@@ -100,9 +98,18 @@ class FunCaptchaRollballPipeline:
    def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        img_h, img_w = target_size
-        image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
+        image = image.convert("RGB")
+        if self.preprocess_mode == "rgb_centered":
+            image = image.resize((img_w, img_h), Image.BILINEAR)
+        elif self.preprocess_mode == "rgb_255":
+            # Match the preprocessing behavior of funcaptcha-server's existing ONNX.
+            image = image.resize((img_w, img_h))
+        else:
+            raise ValueError(f"不支持的 FunCaptcha 预处理模式: {self.preprocess_mode}")

        arr = np.asarray(image, dtype=np.float32) / 255.0
-        arr = (arr - self.mean) / self.std
+        if self.preprocess_mode == "rgb_centered":
+            arr = (arr - self.mean) / self.std
        arr = np.transpose(arr, (2, 0, 1))
        return arr.reshape(1, 3, img_h, img_w)
@@ -115,3 +122,69 @@ class FunCaptchaRollballPipeline:
        if isinstance(image, bytes):
            return Image.open(io.BytesIO(image)).convert("RGB")
        raise TypeError(f"不支持的图片输入类型: {type(image)}")

+    def _resolve_model_path(self) -> Path:
+        candidates = [self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"]
+
+        env_name = self.task_cfg.get("external_model_env")
+        env_value = os.getenv(env_name) if env_name else None
+        if env_value:
+            candidates.append(Path(env_value).expanduser())
+
+        for fallback in self.task_cfg.get("fallback_model_paths", []):
+            candidates.append(Path(fallback).expanduser())
+
+        for candidate in candidates:
+            if candidate.exists():
+                return candidate
+
+        tried = ", ".join(str(path) for path in candidates)
+        raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型,已尝试: {tried}")
+
+    @staticmethod
+    def _resolve_preprocess_mode(metadata: dict) -> str:
+        preprocess = metadata.get("preprocess")
+        if preprocess:
+            return str(preprocess)
+        if metadata.get("task") == "funcaptcha_siamese":
+            return "rgb_centered"
+        return "rgb_255"
+
+    def _run_model(
+        self,
+        input_defs,
+        input_names: list[str],
+        candidates: np.ndarray,
+        reference_batch: np.ndarray,
+    ) -> np.ndarray:
+        batch_size = candidates.shape[0]
+        batch_axis = None
+        for input_def in input_defs:
+            shape = getattr(input_def, "shape", None)
+            if isinstance(shape, (list, tuple)) and shape:
+                batch_axis = shape[0]
+                break
+
+        if batch_axis in (None, "batch", "None", -1) or not isinstance(batch_axis, int):
+            return self.session.run(None, {
+                input_names[0]: candidates,
+                input_names[1]: reference_batch,
+            })[0].reshape(-1)
+
+        if batch_axis == batch_size:
+            return self.session.run(None, {
+                input_names[0]: candidates,
+                input_names[1]: reference_batch,
+            })[0].reshape(-1)
+
+        if batch_axis != 1:
+            raise RuntimeError(f"专项模型不支持当前 batch 维度: expected={batch_axis} actual={batch_size}")
+
+        outputs = []
+        for idx in range(batch_size):
+            out = self.session.run(None, {
+                input_names[0]: candidates[idx:idx + 1],
+                input_names[1]: reference_batch[idx:idx + 1],
+            })[0]
+            outputs.append(np.asarray(out, dtype=np.float32).reshape(-1))
+        return np.concatenate(outputs, axis=0)
@@ -69,18 +69,21 @@ class _FakeSessionOptions:


class _FakeInput:
-    def __init__(self, name):
+    def __init__(self, name, shape=None):
        self.name = name
+        self.shape = shape


class _FakeSession:
    def __init__(self, path, *args, **kwargs):
        self.path = path
+        self.last_feed_dict = None

    def get_inputs(self):
        return [_FakeInput("candidate"), _FakeInput("reference")]

    def run(self, output_names, feed_dict):
+        self.last_feed_dict = feed_dict
        batch_size = next(iter(feed_dict.values())).shape[0]
        logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
        if batch_size >= 3:
@@ -93,6 +96,29 @@ class _FakeOrt:
    InferenceSession = _FakeSession


+class _Batch1FakeSession(_FakeSession):
+    def __init__(self, path, *args, **kwargs):
+        super().__init__(path, *args, **kwargs)
+        self.run_calls = 0
+
+    def get_inputs(self):
+        shape = [1, 3, 48, 48]
+        return [_FakeInput("candidate", shape=shape), _FakeInput("reference", shape=shape)]
+
+    def run(self, output_names, feed_dict):
+        self.run_calls += 1
+        candidate = feed_dict["candidate"]
+        reference = feed_dict["reference"]
+        assert candidate.shape == (1, 3, 48, 48)
+        assert reference.shape == (1, 3, 48, 48)
+        return super().run(output_names, feed_dict)
+
+
+class _Batch1FakeOrt:
+    SessionOptions = _FakeSessionOptions
+    InferenceSession = _Batch1FakeSession
+
+
class TestFunCaptchaPipeline:
    def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
        model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
@@ -102,6 +128,7 @@ class TestFunCaptchaPipeline:
            {
                "model_name": "funcaptcha_rollball_animals",
                "task": "funcaptcha_siamese",
+               "preprocess": "rgb_centered",
                "question": "4_3d_rollball_animals",
                "num_candidates": 4,
                "tile_size": [200, 200],
@@ -121,3 +148,40 @@ class TestFunCaptchaPipeline:
        assert result["objects"] == [2]
        assert result["result"] == "2"
        assert len(result["scores"]) == 4
+        assert pipeline.preprocess_mode == "rgb_centered"
+
+    def test_pipeline_uses_external_model_env_without_metadata(self, tmp_path, monkeypatch):
+        external_model = tmp_path / "external_rollball.onnx"
+        external_model.touch()
+        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
+        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
+
+        image = Image.new("RGB", (800, 400), color=(128, 128, 128))
+        sample_path = tmp_path / "0_demo.png"
+        image.save(sample_path)
+
+        empty_models_dir = tmp_path / "missing_models"
+        pipeline = FunCaptchaRollballPipeline(models_dir=empty_models_dir)
+        result = pipeline.solve(sample_path)
+
+        assert result["objects"] == [2]
+        assert pipeline.model_path == external_model
+        assert pipeline.preprocess_mode == "rgb_255"
+        candidate = pipeline.session.last_feed_dict["candidate"]
+        assert candidate.shape == (4, 3, 48, 48)
+        assert candidate[0, 0, 0, 0] == pytest.approx(128 / 255.0, abs=1e-6)
+
+    def test_pipeline_handles_external_fixed_batch_model(self, tmp_path, monkeypatch):
+        external_model = tmp_path / "external_rollball.onnx"
+        external_model.touch()
+        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
+        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _Batch1FakeOrt)
+
+        sample_path = tmp_path / "0_demo.png"
+        _build_rollball_image(sample_path, answer_idx=0)
+
+        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path / "missing_models")
+        result = pipeline.solve(sample_path)
+
+        assert result["objects"] == [0]
+        assert pipeline.session.run_calls == 4