Support external FunCaptcha ONNX fallback
@@ -21,14 +21,14 @@ Use `uv` for environment and dependency management.
 - `uv run captcha export --model 3d_text` maps to `threed_text`. The export loader also accepts internal artifact names such as `threed_rotate`, `gap_detector`, `rotation_regressor`, and `funcaptcha_rollball_animals`; `4_3d_rollball_animals` is accepted as an alias for that FunCaptcha artifact.
 - `uv run captcha predict image.png` runs auto-routing inference. Add `--type normal` to skip classification.
 - `uv run captcha predict-dir ./test_images` runs batch inference for `.png` and `.jpg` files.
-- `uv run captcha predict-funcaptcha image.jpg --question 4_3d_rollball_animals` runs the dedicated FunCaptcha matcher and returns `objects`.
+- `uv run captcha predict-funcaptcha image.jpg --question 4_3d_rollball_animals` runs the dedicated FunCaptcha matcher and returns `objects`. It resolves the ONNX model in this order: `onnx_models/funcaptcha_rollball_animals.onnx` -> env `FUNCAPTCHA_ROLLBALL_MODEL_PATH` -> configured fallback path such as the sibling `funcaptcha-server/model/4_3d_rollball_animals.onnx`.
 - `uv run captcha solve slide --bg bg.png [--tpl tpl.png]` runs the slide solver. It uses template matching first when `--tpl` is provided, then OpenCV edge detection, then CNN fallback.
 - `uv run captcha solve rotate --image img.png` runs the rotate solver.
 - `uv run captcha serve --host 0.0.0.0 --port 8080` starts the implemented FastAPI service in `server.py`. It supports synchronous `/solve` and `/solve/upload`, plus async task endpoints `/createTask`, `/getTaskResult`, and `/getBalance`, with `/api/v1/*` compatibility aliases. If `CLIENT_KEY` is set in the environment, task endpoints require a matching `clientKey`. `createTask` accepts `callbackUrl`, `softId`, `languagePool`, and optional `task.question`; `task.question=4_3d_rollball_animals` routes to the dedicated FunCaptcha matcher and returns `solution.objects`. `callbackUrl` receives a form-encoded completion callback with configurable retry/backoff in `SERVER_CONFIG`. If `CALLBACK_SIGNING_SECRET` is set, callback requests include HMAC-SHA256 signature headers. Task responses also expose extra `task` / `callback` metadata for async debugging, and task state is persisted under `data/server_tasks/`.
 - `uv run pytest` runs the test suite.

 ## Coding Style & Naming Conventions

-Target Python 3.10-3.12 and follow the existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep public captcha type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Internal checkpoint/ONNX artifact names use `threed_text`, `threed_rotate`, `threed_slider`, and `funcaptcha_rollball_animals`; solver artifacts are `gap_detector` and `rotation_regressor`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ONNX ops, and greedy CTC decoding for OCR models. `normal` uses `NORMAL_CHARS`, `math` uses `MATH_CHARS` and must be post-processed through `inference/math_eval.py`, and `3d_text` uses `THREED_CHARS`. `3d_rotate` and `3d_slider` output sigmoid values in `[0, 1]` and scale them with `REGRESSION_RANGE`; the rotate solver model outputs `(sin, cos)` on RGB input. The FunCaptcha matcher is a dual-input RGB Siamese model keyed by `task.question`, not by `captchaType`.
+Target Python 3.10-3.12 and follow the existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep public captcha type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Internal checkpoint/ONNX artifact names use `threed_text`, `threed_rotate`, `threed_slider`, and `funcaptcha_rollball_animals`; solver artifacts are `gap_detector` and `rotation_regressor`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ONNX ops, and greedy CTC decoding for OCR models. `normal` uses `NORMAL_CHARS`, `math` uses `MATH_CHARS` and must be post-processed through `inference/math_eval.py`, and `3d_text` uses `THREED_CHARS`. `3d_rotate` and `3d_slider` output sigmoid values in `[0, 1]` and scale them with `REGRESSION_RANGE`; the rotate solver model outputs `(sin, cos)` on RGB input. The FunCaptcha matcher is a dual-input RGB Siamese model keyed by `task.question`, not by `captchaType`. Runtime ONNX artifacts belong under `onnx_models/`, not `models/`; external FunCaptcha ONNX files may omit metadata, in which case inference must preserve the external preprocessing contract instead of assuming the repo's centered RGB normalization.
 - Do not casually upgrade `torch` or `torchvision`: newer CUDA 12.8 wheels in this repo's previous environment dropped `sm_61` kernels and failed on GTX 1050 Ti. Re-verify GPU execution before changing the pinned pair.

 ## Training & Data Rules
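The serve bullet above says callbacks are signed with HMAC-SHA256 over a form-encoded body when `CALLBACK_SIGNING_SECRET` is set. A minimal sketch of computing and verifying such a signature; the field-sorting canonicalization and hex encoding here are assumptions for illustration, not the repo's actual wire contract:

```python
import hashlib
import hmac
from urllib.parse import urlencode

def sign_callback(secret: str, form_fields: dict) -> str:
    """HMAC-SHA256 hex digest over the form-encoded callback body."""
    body = urlencode(sorted(form_fields.items())).encode("utf-8")
    return hmac.new(secret.encode("utf-8"), body, hashlib.sha256).hexdigest()

def verify_callback(secret: str, form_fields: dict, signature: str) -> bool:
    """Constant-time comparison so the expected digest cannot leak via timing."""
    return hmac.compare_digest(sign_callback(secret, form_fields), signature)

fields = {"taskId": "abc123", "status": "ready"}
sig = sign_callback("my-signing-secret", fields)
assert verify_callback("my-signing-secret", fields, sig)
# any tampering with the form fields invalidates the signature
assert not verify_callback("my-signing-secret", {**fields, "status": "error"}, sig)
```

Using `hmac.compare_digest` rather than `==` is the standard choice for signature checks on a public endpoint.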
CLAUDE.md (12 changed lines)
@@ -484,11 +484,23 @@ uv run python cli.py predict image.png --type normal # skip classification
 uv run python cli.py predict image.png --type 3d_rotate # force the rotate type
 uv run python cli.py predict-dir ./test_images/ # batch recognition
 uv run python cli.py predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
+
+FUNCAPTCHA_ROLLBALL_MODEL_PATH=/path/to/4_3d_rollball_animals.onnx \
+uv run python cli.py predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals

 # start the HTTP service (install the optional server dependencies first)
 uv run python cli.py serve --port 8080
 ```

+FunCaptcha-specific ONNX lookup order:
+
+- `onnx_models/funcaptcha_rollball_animals.onnx`
+- the `FUNCAPTCHA_ROLLBALL_MODEL_PATH` environment variable
+- the default fallback `/mnt/data/code/python/funcaptcha-server/model/4_3d_rollball_animals.onnx`
+
+Notes:
+
+- the `models/` directory is for Python source code; do not put ONNX artifacts there
+- FunCaptcha ONNX files exported by this repo carry sidecar metadata and use centered RGB preprocessing
+- if an external compatible ONNX has no metadata, the inference layer must fall back to the `/255.0` RGB preprocessing contract of `funcaptcha-server`

 ## HTTP service (server.py, optional)

 Pure inference service: it does not depend on torch or the training code, only on onnxruntime + FastAPI.
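The two preprocessing contracts listed above differ only after the shared `/255.0` step: repo-exported models additionally center the values, metadata-less external models do not. A numeric sketch, assuming `mean = std = 0.5` for illustration (the repo's real values live in `INFERENCE_CONFIG["normalize_mean"]` / `["normalize_std"]`):

```python
import numpy as np

def preprocess(pixels, mode, mean=0.5, std=0.5):
    """Scale uint8 RGB values to float32 under the selected contract."""
    arr = np.asarray(pixels, dtype=np.float32) / 255.0  # both contracts start here
    if mode == "rgb_centered":      # repo-exported ONNX (sidecar metadata present)
        arr = (arr - mean) / std
    elif mode != "rgb_255":         # metadata-less external funcaptcha-server ONNX
        raise ValueError(f"unknown preprocess mode: {mode}")
    return arr

px = [0, 128, 255]
print(preprocess(px, "rgb_255"))       # [0.0, ~0.502, 1.0]
print(preprocess(px, "rgb_centered"))  # [-1.0, ~0.004, 1.0]
```

Feeding `rgb_255` inputs to a model trained on centered values (or vice versa) shifts every activation, which is why the contract must follow the model rather than the repo default.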
README.md (14 changed lines)
@@ -143,6 +143,20 @@ uv run captcha export --model 4_3d_rollball_animals
 uv run captcha predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
 ```

+If you do not have training data yet, you can also reuse an external ONNX directly:
+
+```bash
+FUNCAPTCHA_ROLLBALL_MODEL_PATH=/path/to/4_3d_rollball_animals.onnx \
+uv run captcha predict-funcaptcha challenge.jpg --question 4_3d_rollball_animals
+```
+
+The inference lookup order is:
+
+- `onnx_models/funcaptcha_rollball_animals.onnx`
+- the `FUNCAPTCHA_ROLLBALL_MODEL_PATH` environment variable
+- the default fallback `/mnt/data/code/python/funcaptcha-server/model/4_3d_rollball_animals.onnx`
+
+Do not put ONNX files into `models/`; that directory is for Python model-definition source code, and runtime model artifacts belong in `onnx_models/`.

 ## HTTP API

 ```bash
@@ -254,6 +254,10 @@ FUN_CAPTCHA_TASKS = {
         "num_candidates": 4,
         "answer_index_base": 0,
         "channels": 3,
+        "external_model_env": "FUNCAPTCHA_ROLLBALL_MODEL_PATH",
+        "fallback_model_paths": [
+            str(PROJECT_ROOT.parent / "funcaptcha-server" / "model" / "4_3d_rollball_animals.onnx"),
+        ],
     },
 }
@@ -205,6 +205,7 @@ def _load_and_export(model_name: str):
     metadata = {
         "model_name": model_name,
         "task": "funcaptcha_siamese",
+        "preprocess": "rgb_centered",
         "question": question,
         "num_candidates": int(ckpt.get("num_candidates", task_cfg["num_candidates"])),
         "tile_size": list(ckpt.get("tile_size", task_cfg["tile_size"])),
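The `preprocess` key above travels with the exported model as sidecar metadata. A sketch of writing and reading such a sidecar, assuming a `.json` file sharing the ONNX file's stem; the repo's actual naming is whatever `load_model_metadata` implements:

```python
import json
import tempfile
from pathlib import Path

def write_sidecar(onnx_path: Path, metadata: dict) -> Path:
    """Write export metadata as a JSON sidecar next to the ONNX artifact."""
    sidecar = onnx_path.with_suffix(".json")
    sidecar.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    return sidecar

def read_sidecar(onnx_path: Path) -> dict:
    """Missing sidecar -> {}: external models may ship without metadata."""
    sidecar = onnx_path.with_suffix(".json")
    if not sidecar.exists():
        return {}
    return json.loads(sidecar.read_text(encoding="utf-8"))

with tempfile.TemporaryDirectory() as tmp:
    model = Path(tmp) / "funcaptcha_rollball_animals.onnx"
    model.touch()
    write_sidecar(model, {"task": "funcaptcha_siamese", "preprocess": "rgb_centered"})
    assert read_sidecar(model)["preprocess"] == "rgb_centered"
    # a metadata-less external model falls back to the rgb_255 contract
    assert read_sidecar(Path(tmp) / "external.onnx") == {}
```

Returning `{}` instead of raising keeps the external-ONNX path working: the pipeline treats an empty dict as "no metadata" and selects the fallback preprocessing mode.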
@@ -5,6 +5,7 @@ FunCaptcha-specific ONNX inference.
 from __future__ import annotations

 import io
+import os
 import time
 from pathlib import Path
@@ -32,9 +33,7 @@ class FunCaptchaRollballPipeline:
         self.question = question
         self.task_cfg = FUN_CAPTCHA_TASKS[question]
         self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
-        self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
-        if not self.model_path.exists():
-            raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")
+        self.model_path = self._resolve_model_path()

         opts = ort.SessionOptions()
         opts.inter_op_num_threads = 1
@@ -45,6 +44,7 @@ class FunCaptchaRollballPipeline:
             providers=["CPUExecutionProvider"],
         )
         self.metadata = load_model_metadata(self.model_path) or {}
+        self.preprocess_mode = self._resolve_preprocess_mode(self.metadata)
         self.mean = float(INFERENCE_CONFIG["normalize_mean"])
         self.std = float(INFERENCE_CONFIG["normalize_std"])
         self.answer_index_base = int(
@@ -57,14 +57,12 @@ class FunCaptchaRollballPipeline:
         candidates, reference = self._split_challenge(challenge)

         ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)
-        input_names = [inp.name for inp in self.session.get_inputs()]
+        input_defs = self.session.get_inputs()
+        input_names = [inp.name for inp in input_defs]
         if len(input_names) != 2:
             raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

-        logits = self.session.run(None, {
-            input_names[0]: candidates,
-            input_names[1]: ref_batch,
-        })[0].reshape(-1)
+        logits = self._run_model(input_defs, input_names, candidates, ref_batch)
         scores = 1.0 / (1.0 + np.exp(-logits))
         answer_idx = int(np.argmax(logits))
         selected = answer_idx + self.answer_index_base
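The scoring step above is a plain sigmoid over the per-candidate logits followed by an argmax; `answer_index_base` only shifts the reported index. A standalone sketch of that tail of `solve`:

```python
import numpy as np

def pick_answer(logits, answer_index_base=0):
    """Turn per-candidate logits into sigmoid scores and a reported index."""
    scores = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float32)))
    answer_idx = int(np.argmax(scores))  # same winner as argmax over the raw logits
    return scores.tolist(), answer_idx + answer_index_base

scores, selected = pick_answer([-1.2, 0.3, 2.5, -0.4])
print(selected)  # 2: the third candidate has the highest match score
```

Because the sigmoid is monotonic, taking the argmax over logits or over scores is equivalent; the scores exist only so callers can inspect per-candidate confidence.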
@@ -100,9 +98,18 @@ class FunCaptchaRollballPipeline:

     def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
         img_h, img_w = target_size
-        image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
+        image = image.convert("RGB")
+        if self.preprocess_mode == "rgb_centered":
+            image = image.resize((img_w, img_h), Image.BILINEAR)
+        elif self.preprocess_mode == "rgb_255":
+            # Match the preprocessing behavior of the existing funcaptcha-server ONNX.
+            image = image.resize((img_w, img_h))
+        else:
+            raise ValueError(f"不支持的 FunCaptcha 预处理模式: {self.preprocess_mode}")

         arr = np.asarray(image, dtype=np.float32) / 255.0
-        arr = (arr - self.mean) / self.std
+        if self.preprocess_mode == "rgb_centered":
+            arr = (arr - self.mean) / self.std
         arr = np.transpose(arr, (2, 0, 1))
         return arr.reshape(1, 3, img_h, img_w)
@@ -115,3 +122,69 @@ class FunCaptchaRollballPipeline:
         if isinstance(image, bytes):
             return Image.open(io.BytesIO(image)).convert("RGB")
         raise TypeError(f"不支持的图片输入类型: {type(image)}")
+
+    def _resolve_model_path(self) -> Path:
+        candidates = [self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"]
+
+        env_name = self.task_cfg.get("external_model_env")
+        env_value = os.getenv(env_name) if env_name else None
+        if env_value:
+            candidates.append(Path(env_value).expanduser())
+
+        for fallback in self.task_cfg.get("fallback_model_paths", []):
+            candidates.append(Path(fallback).expanduser())
+
+        for candidate in candidates:
+            if candidate.exists():
+                return candidate
+
+        tried = ", ".join(str(path) for path in candidates)
+        raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型,已尝试: {tried}")
+
+    @staticmethod
+    def _resolve_preprocess_mode(metadata: dict) -> str:
+        preprocess = metadata.get("preprocess")
+        if preprocess:
+            return str(preprocess)
+        if metadata.get("task") == "funcaptcha_siamese":
+            return "rgb_centered"
+        return "rgb_255"
+
+    def _run_model(
+        self,
+        input_defs,
+        input_names: list[str],
+        candidates: np.ndarray,
+        reference_batch: np.ndarray,
+    ) -> np.ndarray:
+        batch_size = candidates.shape[0]
+        batch_axis = None
+        for input_def in input_defs:
+            shape = getattr(input_def, "shape", None)
+            if isinstance(shape, (list, tuple)) and shape:
+                batch_axis = shape[0]
+                break
+
+        if batch_axis in (None, "batch", "None", -1) or not isinstance(batch_axis, int):
+            return self.session.run(None, {
+                input_names[0]: candidates,
+                input_names[1]: reference_batch,
+            })[0].reshape(-1)
+
+        if batch_axis == batch_size:
+            return self.session.run(None, {
+                input_names[0]: candidates,
+                input_names[1]: reference_batch,
+            })[0].reshape(-1)
+
+        if batch_axis != 1:
+            raise RuntimeError(f"专项模型不支持当前 batch 维度: expected={batch_axis} actual={batch_size}")
+
+        outputs = []
+        for idx in range(batch_size):
+            out = self.session.run(None, {
+                input_names[0]: candidates[idx:idx + 1],
+                input_names[1]: reference_batch[idx:idx + 1],
+            })[0]
+            outputs.append(np.asarray(out, dtype=np.float32).reshape(-1))
+        return np.concatenate(outputs, axis=0)
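The `_run_model` logic above reduces to one decision on the first input's leading dimension: a symbolic or negative axis means the model accepts any batch in one call, a fixed axis equal to the batch also runs once, and a fixed axis of 1 forces a per-sample loop whose logits are stitched back together. A standalone sketch of that dispatch, with `fake_run` standing in for `session.run`:

```python
import numpy as np

def run_batched(input_shape, batch, run_once):
    """Dispatch an (n, C, H, W) batch against a model whose first axis may be fixed."""
    batch_axis = input_shape[0] if input_shape else None
    if not isinstance(batch_axis, int) or batch_axis < 0:
        return run_once(batch).reshape(-1)   # symbolic/dynamic axis: single call
    if batch_axis == batch.shape[0]:
        return run_once(batch).reshape(-1)   # fixed axis already matches the batch
    if batch_axis != 1:
        raise RuntimeError(f"unsupported fixed batch axis: {batch_axis}")
    # batch-1 model: score each sample separately and concatenate the logits
    outs = [run_once(batch[i:i + 1]).reshape(-1) for i in range(batch.shape[0])]
    return np.concatenate(outs, axis=0)

def fake_run(chunk):
    # stand-in for session.run: one "logit" per sample
    return chunk.sum(axis=(1, 2, 3), keepdims=True)

x = np.ones((4, 3, 2, 2), dtype=np.float32)
print(run_batched([1, 3, 2, 2], x, fake_run).shape)        # (4,) via four batch-1 calls
print(run_batched(["batch", 3, 2, 2], x, fake_run).shape)  # (4,) via one call
```

ONNX Runtime reports symbolic dimensions as strings (for example `"batch"`) in `NodeArg.shape`, which is why a non-`int` leading dimension is treated as dynamic here.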
@@ -69,18 +69,21 @@ class _FakeSessionOptions:


 class _FakeInput:
-    def __init__(self, name):
+    def __init__(self, name, shape=None):
         self.name = name
+        self.shape = shape


 class _FakeSession:
     def __init__(self, path, *args, **kwargs):
         self.path = path
+        self.last_feed_dict = None

     def get_inputs(self):
         return [_FakeInput("candidate"), _FakeInput("reference")]

     def run(self, output_names, feed_dict):
+        self.last_feed_dict = feed_dict
         batch_size = next(iter(feed_dict.values())).shape[0]
         logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
         if batch_size >= 3:
@@ -93,6 +96,29 @@ class _FakeOrt:
     InferenceSession = _FakeSession


+class _Batch1FakeSession(_FakeSession):
+    def __init__(self, path, *args, **kwargs):
+        super().__init__(path, *args, **kwargs)
+        self.run_calls = 0
+
+    def get_inputs(self):
+        shape = [1, 3, 48, 48]
+        return [_FakeInput("candidate", shape=shape), _FakeInput("reference", shape=shape)]
+
+    def run(self, output_names, feed_dict):
+        self.run_calls += 1
+        candidate = feed_dict["candidate"]
+        reference = feed_dict["reference"]
+        assert candidate.shape == (1, 3, 48, 48)
+        assert reference.shape == (1, 3, 48, 48)
+        return super().run(output_names, feed_dict)
+
+
+class _Batch1FakeOrt:
+    SessionOptions = _FakeSessionOptions
+    InferenceSession = _Batch1FakeSession
+
+
 class TestFunCaptchaPipeline:
     def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
         model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
@@ -102,6 +128,7 @@ class TestFunCaptchaPipeline:
             {
                 "model_name": "funcaptcha_rollball_animals",
                 "task": "funcaptcha_siamese",
+                "preprocess": "rgb_centered",
                 "question": "4_3d_rollball_animals",
                 "num_candidates": 4,
                 "tile_size": [200, 200],
@@ -121,3 +148,40 @@ class TestFunCaptchaPipeline:
         assert result["objects"] == [2]
         assert result["result"] == "2"
         assert len(result["scores"]) == 4
+        assert pipeline.preprocess_mode == "rgb_centered"
+
+    def test_pipeline_uses_external_model_env_without_metadata(self, tmp_path, monkeypatch):
+        external_model = tmp_path / "external_rollball.onnx"
+        external_model.touch()
+        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
+        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
+
+        image = Image.new("RGB", (800, 400), color=(128, 128, 128))
+        sample_path = tmp_path / "0_demo.png"
+        image.save(sample_path)
+
+        empty_models_dir = tmp_path / "missing_models"
+        pipeline = FunCaptchaRollballPipeline(models_dir=empty_models_dir)
+        result = pipeline.solve(sample_path)
+
+        assert result["objects"] == [2]
+        assert pipeline.model_path == external_model
+        assert pipeline.preprocess_mode == "rgb_255"
+        candidate = pipeline.session.last_feed_dict["candidate"]
+        assert candidate.shape == (4, 3, 48, 48)
+        assert candidate[0, 0, 0, 0] == pytest.approx(128 / 255.0, abs=1e-6)
+
+    def test_pipeline_handles_external_fixed_batch_model(self, tmp_path, monkeypatch):
+        external_model = tmp_path / "external_rollball.onnx"
+        external_model.touch()
+        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
+        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _Batch1FakeOrt)
+
+        sample_path = tmp_path / "0_demo.png"
+        _build_rollball_image(sample_path, answer_idx=0)
+
+        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path / "missing_models")
+        result = pipeline.solve(sample_path)
+
+        assert result["objects"] == [0]
+        assert pipeline.session.run_calls == 4