Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -0,0 +1,109 @@
"""
测试合成数据指纹与 ONNX metadata 辅助逻辑。
"""
from pathlib import Path
from PIL import Image
from inference.model_metadata import load_model_metadata, write_model_metadata
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
load_dataset_manifest,
)
class AdoptableGenerator:
def generate_dataset(self, num_samples: int, output_dir: str):
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for idx in range(num_samples):
Image.new("L", (2, 2), color=255).save(out_dir / f"kept_{idx:06d}.png")
class RefreshingGenerator:
def generate_dataset(self, num_samples: int, output_dir: str):
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for idx in range(num_samples):
Image.new("L", (2, 2), color=255).save(out_dir / f"20÷4_{idx:06d}.png")
class TestDataFingerprint:
def test_adopt_existing_dataset_when_manifest_missing(self, tmp_path):
dataset_dir = tmp_path / "normal"
dataset_dir.mkdir()
for idx, label in enumerate(["AB12", "CD34"]):
Image.new("L", (2, 2), color=255).save(dataset_dir / f"{label}_{idx:06d}.png")
state = ensure_synthetic_dataset(
dataset_dir,
generator_cls=AdoptableGenerator,
spec=build_dataset_spec(
AdoptableGenerator,
config_key="normal",
config_snapshot={"chars": "ABCD1234"},
),
gen_count=2,
exact_count=2,
adopt_if_missing=True,
)
manifest = load_dataset_manifest(dataset_dir)
assert state["adopted"] is True
assert state["refreshed"] is False
assert manifest is not None
assert manifest["adopted_existing"] is True
assert manifest["sample_count"] == 2
def test_refresh_dataset_when_validator_fails(self, tmp_path):
dataset_dir = tmp_path / "math"
dataset_dir.mkdir()
for idx, label in enumerate(["1+1", "2-1"]):
Image.new("L", (2, 2), color=255).save(dataset_dir / f"{label}_{idx:06d}.png")
state = ensure_synthetic_dataset(
dataset_dir,
generator_cls=RefreshingGenerator,
spec=build_dataset_spec(
RefreshingGenerator,
config_key="math",
config_snapshot={"operators": ["+", "-", "÷"]},
),
gen_count=2,
exact_count=2,
validator=lambda files: any("÷" in path.stem for path in files),
adopt_if_missing=True,
)
files = sorted(dataset_dir.glob("*.png"))
manifest = load_dataset_manifest(dataset_dir)
assert state["refreshed"] is True
assert state["adopted"] is False
assert manifest is not None
assert manifest["adopted_existing"] is False
assert len(files) == 2
assert all("÷" in path.stem for path in files)
class TestModelMetadata:
def test_write_and_load_model_metadata(self, tmp_path):
model_path = tmp_path / "normal.onnx"
model_path.touch()
write_model_metadata(
model_path,
{
"model_name": "normal",
"task": "ctc",
"chars": "ABC",
"input_shape": [1, 40, 120],
},
)
metadata = load_model_metadata(model_path)
assert metadata is not None
assert metadata["version"] == 1
assert metadata["chars"] == "ABC"
assert metadata["task"] == "ctc"

123
tests/test_funcaptcha.py Normal file
View File

@@ -0,0 +1,123 @@
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
import inference.fun_captcha as fun_module
from inference.fun_captcha import FunCaptchaRollballPipeline
from inference.model_metadata import write_model_metadata
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform
def _build_rollball_image(path: Path, answer_idx: int = 2):
colors = [
(255, 80, 80),
(80, 255, 80),
(80, 80, 255),
(255, 220, 80),
]
image = Image.new("RGB", (800, 400), color=(245, 245, 245))
for idx, color in enumerate(colors):
tile = Image.new("RGB", (200, 200), color=color)
image.paste(tile, (idx * 200, 0))
reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
image.paste(reference, (0, 200))
image.save(path)
class TestFunCaptchaChallengeDataset:
def test_dataset_splits_candidates_and_reference(self, tmp_path):
sample_path = tmp_path / "2_demo.png"
_build_rollball_image(sample_path, answer_idx=2)
dataset = FunCaptchaChallengeDataset(
dirs=[tmp_path],
task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
transform=build_val_rgb_transform(*IMAGE_SIZE["funcaptcha_rollball_animals"]),
)
candidates, reference, answer_idx = dataset[0]
assert candidates.shape == (4, 3, 48, 48)
assert reference.shape == (3, 48, 48)
assert int(answer_idx.item()) == 2
class TestFunCaptchaSiamese:
def test_forward_shape(self):
model = FunCaptchaSiamese()
model.eval()
candidate = torch.randn(5, 3, 48, 48)
reference = torch.randn(5, 3, 48, 48)
out = model(candidate, reference)
assert out.shape == (5, 1)
def test_param_count_reasonable(self):
model = FunCaptchaSiamese()
n = sum(p.numel() for p in model.parameters())
assert n < 450_000, f"too many params: {n}"
class _FakeSessionOptions:
def __init__(self):
self.inter_op_num_threads = 0
self.intra_op_num_threads = 0
class _FakeInput:
def __init__(self, name):
self.name = name
class _FakeSession:
def __init__(self, path, *args, **kwargs):
self.path = path
def get_inputs(self):
return [_FakeInput("candidate"), _FakeInput("reference")]
def run(self, output_names, feed_dict):
batch_size = next(iter(feed_dict.values())).shape[0]
logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
if batch_size >= 3:
logits[2, 0] = 0.95
return [logits]
class _FakeOrt:
SessionOptions = _FakeSessionOptions
InferenceSession = _FakeSession
class TestFunCaptchaPipeline:
def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
model_path.touch()
write_model_metadata(
model_path,
{
"model_name": "funcaptcha_rollball_animals",
"task": "funcaptcha_siamese",
"question": "4_3d_rollball_animals",
"num_candidates": 4,
"tile_size": [200, 200],
"reference_box": [0, 200, 200, 400],
"answer_index_base": 0,
"input_shape": [3, 48, 48],
},
)
monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
sample_path = tmp_path / "1_demo.png"
_build_rollball_image(sample_path, answer_idx=1)
pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
result = pipeline.solve(sample_path)
assert result["question"] == "4_3d_rollball_animals"
assert result["objects"] == [2]
assert result["result"] == "2"
assert len(result["scores"]) == 4

View File

@@ -9,7 +9,14 @@ import re
import pytest
from PIL import Image
from config import GENERATE_CONFIG, NORMAL_CHARS, MATH_CHARS, THREED_CHARS, SOLVER_CONFIG
from config import (
GENERATE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
SOLVER_CONFIG,
SOLVER_REGRESSION_RANGE,
)
from generators import (
NormalCaptchaGenerator,
MathCaptchaGenerator,
@@ -67,6 +74,10 @@ class TestMathCaptchaGenerator:
img, label = self.gen.generate()
assert re.match(r"^\d+[+\-×÷]\d+$", label), f"unexpected label format: {label!r}"
def test_generate_with_division_text(self):
img, label = self.gen.generate(text="20÷4")
assert label == "20÷4"
class TestThreeDCaptchaGenerator:
def setup_method(self):
@@ -150,7 +161,25 @@ class TestSlideDataGenerator:
def test_label_is_numeric(self):
img, label = self.gen.generate()
val = int(label)
assert val >= 0
gs = self.gen.gap_size
margin = gs + 10
assert margin + gs // 2 <= val <= self.gen.width - margin + gs // 2
def test_labels_normalize_inside_solver_range(self, tmp_path):
for idx in range(3):
img, label = self.gen.generate()
img.save(tmp_path / f"{label}_{idx:06d}.png")
from training.dataset import RegressionDataset
ds = RegressionDataset(
dirs=[tmp_path],
label_range=SOLVER_REGRESSION_RANGE["slide"],
transform=None,
)
assert len(ds.samples) == 3
for _, norm in ds.samples:
assert 0.0 < norm < 1.0
class TestRotateSolverDataGenerator:

View File

@@ -8,11 +8,15 @@
"""
import math
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
from inference.math_eval import eval_captcha_math
from inference.model_metadata import write_model_metadata
import inference.pipeline as pipeline_module
from inference.pipeline import CaptchaPipeline
@@ -97,6 +101,96 @@ class TestCTCGreedyDecode:
assert result == "AA"
class _FakeInput:
name = "input"
class _FakeSession:
def __init__(self, path, *args, **kwargs):
self.model_name = Path(path).name
def get_inputs(self):
return [_FakeInput()]
def run(self, output_names, feed_dict):
if self.model_name == "classifier.onnx":
return [np.array([[0.1, 0.9]], dtype=np.float32)]
if self.model_name == "normal.onnx":
logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
logits[0, 0, 2] = 10.0
logits[1, 0, 0] = 10.0
return [logits]
if self.model_name == "threed_rotate.onnx":
return [np.array([[0.25]], dtype=np.float32)]
raise AssertionError(f"unexpected fake session: {self.model_name}")
class _FakeSessionOptions:
def __init__(self):
self.inter_op_num_threads = 0
self.intra_op_num_threads = 0
class _FakeOrt:
SessionOptions = _FakeSessionOptions
InferenceSession = _FakeSession
class TestPipelineMetadata:
def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
(tmp_path / "classifier.onnx").touch()
write_model_metadata(
tmp_path / "classifier.onnx",
{
"model_name": "classifier",
"task": "classifier",
"class_names": ["math", "normal"],
"input_shape": [1, 64, 128],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
captcha_type = pipeline.classify(Image.new("RGB", (32, 32), color="white"))
assert captcha_type == "normal"
def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
(tmp_path / "normal.onnx").touch()
write_model_metadata(
tmp_path / "normal.onnx",
{
"model_name": "normal",
"task": "ctc",
"chars": "XYZ",
"input_shape": [1, 40, 120],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="normal")
assert result["raw"] == "Y"
assert result["result"] == "Y"
def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
(tmp_path / "threed_rotate.onnx").touch()
write_model_metadata(
tmp_path / "threed_rotate.onnx",
{
"model_name": "threed_rotate",
"task": "regression",
"label_range": [100, 200],
"input_shape": [1, 80, 80],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="3d_rotate")
assert result["raw"] == "125.0"
assert result["result"] == "125"
# ============================================================
# SlideSolver 测试
# ============================================================

423
tests/test_server.py Normal file
View File

@@ -0,0 +1,423 @@
import asyncio
import base64
import time
from pathlib import Path
from types import SimpleNamespace
from urllib.error import URLError
from urllib.parse import urlencode
import pytest
pytest.importorskip("fastapi")
from fastapi.responses import JSONResponse
from config import SERVER_CONFIG
import server as server_module
from server import create_app
class _FakePipeline:
def solve(self, image, captcha_type=None):
return {
"type": captcha_type or "normal",
"result": "A3B8",
"raw": "A3B8",
"time_ms": 1.23,
}
class _FakeFunPipeline:
def solve(self, image):
return {
"type": "funcaptcha",
"question": "4_3d_rollball_animals",
"objects": [2],
"result": "2",
"raw": "2",
"time_ms": 2.34,
}
def _create_test_app(funcaptcha_factories=None):
return create_app(
pipeline_factory=_FakePipeline,
funcaptcha_factories=funcaptcha_factories,
)
def _get_route(app, path: str):
for route in app.routes:
if getattr(route, "path", None) == path:
return route.endpoint
raise AssertionError(f"route not found: {path}")
def _fake_request(host: str = "127.0.0.1"):
return SimpleNamespace(client=SimpleNamespace(host=host))
@pytest.fixture(autouse=True)
def _reset_server_config(monkeypatch, tmp_path):
monkeypatch.setitem(SERVER_CONFIG, "client_key", None)
monkeypatch.setitem(SERVER_CONFIG, "tasks_dir", str(tmp_path / "server_tasks"))
monkeypatch.setitem(SERVER_CONFIG, "task_cost", 0.0)
monkeypatch.setitem(SERVER_CONFIG, "callback_max_retries", 2)
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_delay_seconds", 1.0)
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_backoff", 2.0)
monkeypatch.setitem(SERVER_CONFIG, "callback_signing_secret", None)
def test_solve_base64_returns_sync_payload():
app = _create_test_app()
solve = _get_route(app, "/api/v1/solve")
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
response = asyncio.run(
solve(SimpleNamespace(image=encoded, type="math"))
)
assert response == {
"type": "math",
"result": "A3B8",
"raw": "A3B8",
"time_ms": 1.23,
}
def test_create_task_and_get_task_result():
app = _create_test_app()
create_task = _get_route(app, "/createTask")
get_task_result = _get_route(app, "/getTaskResult")
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
create_response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="local",
task=SimpleNamespace(
type="ImageToTextTaskM1",
body=encoded,
image=None,
captchaType="normal",
),
),
_fake_request("10.0.0.8"),
)
)
assert create_response["errorId"] == 0
assert create_response["status"] == "processing"
assert isinstance(create_response["createTime"], int)
assert isinstance(create_response["expiresAt"], int)
task_id = create_response["taskId"]
result_response = None
for _ in range(20):
result_response = asyncio.run(
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
)
if result_response.get("status") == "ready":
break
assert result_response is not None
assert result_response["errorId"] == 0
assert result_response["status"] == "ready"
assert result_response["solution"] == {
"text": "A3B8",
"answer": "A3B8",
"raw": "A3B8",
"captchaType": "normal",
"timeMs": 1.23,
}
assert result_response["cost"] == "0.00000"
assert result_response["ip"] == "10.0.0.8"
assert result_response["solveCount"] == 1
assert result_response["task"] == {
"type": "ImageToTextTaskM1",
"captchaType": "normal",
}
assert result_response["callback"] == {
"configured": False,
"url": None,
"attempts": 0,
"delivered": False,
"deliveredAt": None,
"lastError": None,
}
assert isinstance(result_response["expiresAt"], int)
def test_get_task_result_returns_not_found_for_unknown_task():
app = _create_test_app()
get_task_result = _get_route(app, "/getTaskResult")
response = asyncio.run(
get_task_result(SimpleNamespace(clientKey="local", taskId="missing-task"))
)
assert response["errorCode"] == "ERROR_TASK_NOT_FOUND"
def test_solve_returns_json_error_for_invalid_base64():
app = _create_test_app()
solve = _get_route(app, "/solve")
response = asyncio.run(
solve(SimpleNamespace(image="not_base64!", type="normal"))
)
assert isinstance(response, JSONResponse)
assert response.status_code == 400
def test_health_alias_reports_client_key_flag(monkeypatch):
app = _create_test_app()
health = _get_route(app, "/api/v1/health")
monkeypatch.setitem(SERVER_CONFIG, "client_key", "secret")
response = health()
assert response["status"] == "ok"
assert response["client_key_required"] is True
assert "ImageToTextTask" in response["supported_task_types"]
def test_client_key_is_required_for_task_api(monkeypatch):
app = _create_test_app()
create_task = _get_route(app, "/api/v1/createTask")
get_balance = _get_route(app, "/api/v1/getBalance")
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
monkeypatch.setitem(SERVER_CONFIG, "client_key", "secret")
create_response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="wrong-key",
task=SimpleNamespace(
type="ImageToTextTask",
body=encoded,
image=None,
captchaType="normal",
),
),
_fake_request(),
)
)
balance_response = asyncio.run(
get_balance(SimpleNamespace(clientKey="wrong-key"))
)
assert create_response["errorCode"] == "ERROR_KEY_DOES_NOT_EXIST"
assert balance_response["errorCode"] == "ERROR_KEY_DOES_NOT_EXIST"
def test_create_task_triggers_callback(monkeypatch):
app = _create_test_app()
create_task = _get_route(app, "/createTask")
callbacks = []
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
def _fake_post_callback(callback_url, payload):
callbacks.append((callback_url, payload))
monkeypatch.setattr(server_module, "_post_callback", _fake_post_callback)
response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="local",
callbackUrl="https://example.com/callback",
task=SimpleNamespace(
type="ImageToTextTask",
body=encoded,
image=None,
captchaType="normal",
),
),
_fake_request("10.0.0.9"),
)
)
assert response["errorId"] == 0
for _ in range(20):
if callbacks:
break
time.sleep(0.01)
assert callbacks == [
(
"https://example.com/callback",
{
"id": response["taskId"],
"taskId": response["taskId"],
"status": "ready",
"errorId": "0",
"code": "A3B8",
"text": "A3B8",
"answer": "A3B8",
"raw": "A3B8",
"captchaType": "normal",
"timeMs": "1.23",
"cost": "0.00000",
},
)
]
def test_create_task_routes_fun_captcha_question():
app = _create_test_app(
funcaptcha_factories={"4_3d_rollball_animals": _FakeFunPipeline}
)
create_task = _get_route(app, "/createTask")
get_task_result = _get_route(app, "/getTaskResult")
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
create_response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="local",
task=SimpleNamespace(
type="FunCaptcha",
body=encoded,
image=None,
captchaType=None,
question="4_3d_rollball_animals",
),
),
_fake_request("10.0.0.7"),
)
)
task_id = create_response["taskId"]
result_response = None
for _ in range(20):
result_response = asyncio.run(
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
)
if result_response.get("status") == "ready":
break
assert result_response["errorId"] == 0
assert result_response["status"] == "ready"
assert result_response["solution"] == {
"objects": [2],
"answer": 2,
"raw": "2",
"timeMs": 2.34,
"question": "4_3d_rollball_animals",
"text": "2",
}
assert result_response["task"] == {
"type": "FunCaptcha",
"captchaType": None,
"question": "4_3d_rollball_animals",
}
def test_create_task_retries_callback(monkeypatch):
app = _create_test_app()
create_task = _get_route(app, "/createTask")
attempts = []
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
def _flaky_post_callback(callback_url, payload):
attempts.append((callback_url, payload["taskId"]))
if len(attempts) < 3:
raise URLError("temporary failure")
monkeypatch.setattr(server_module, "_post_callback", _flaky_post_callback)
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_delay_seconds", 0.0)
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_backoff", 1.0)
response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="local",
callbackUrl="https://example.com/callback",
task=SimpleNamespace(
type="ImageToTextTask",
body=encoded,
image=None,
captchaType="normal",
),
),
_fake_request(),
)
)
for _ in range(20):
if len(attempts) >= 3:
break
time.sleep(0.01)
assert response["errorId"] == 0
assert response["status"] == "processing"
assert attempts == [
("https://example.com/callback", response["taskId"]),
("https://example.com/callback", response["taskId"]),
("https://example.com/callback", response["taskId"]),
]
def test_tasks_are_restored_from_disk():
app = _create_test_app()
create_task = _get_route(app, "/createTask")
get_task_result = _get_route(app, "/getTaskResult")
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
create_response = asyncio.run(
create_task(
SimpleNamespace(
clientKey="local",
task=SimpleNamespace(
type="ImageToTextTask",
body=encoded,
image=None,
captchaType="normal",
),
),
_fake_request(),
)
)
task_id = create_response["taskId"]
for _ in range(20):
result = asyncio.run(
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
)
if result.get("status") == "ready":
break
time.sleep(0.01)
task_file = Path(SERVER_CONFIG["tasks_dir"]) / f"{task_id}.json"
assert task_file.exists()
reloaded_app = _create_test_app()
reloaded_get_task_result = _get_route(reloaded_app, "/getTaskResult")
reloaded_result = asyncio.run(
reloaded_get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
)
assert reloaded_result["status"] == "ready"
assert reloaded_result["solution"]["answer"] == "A3B8"
def test_callback_request_includes_signature_headers(monkeypatch):
monkeypatch.setitem(SERVER_CONFIG, "callback_signing_secret", "shared-secret")
request = server_module._build_callback_request(
"https://example.com/callback",
{"taskId": "abc", "status": "ready"},
)
body = urlencode({"taskId": "abc", "status": "ready"}).encode("utf-8")
headers = {key.lower(): value for key, value in request.header_items()}
timestamp = headers["x-captchabreaker-timestamp"]
signature = headers["x-captchabreaker-signature"]
assert headers["content-type"] == "application/x-www-form-urlencoded"
assert headers["x-captchabreaker-signature-alg"] == "hmac-sha256"
assert signature == server_module._sign_callback_payload(body, timestamp, "shared-secret")