Align task API and add FunCaptcha support
This commit is contained in:
109
tests/test_data_fingerprint.py
Normal file
109
tests/test_data_fingerprint.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
测试合成数据指纹与 ONNX metadata 辅助逻辑。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from inference.model_metadata import load_model_metadata, write_model_metadata
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
load_dataset_manifest,
|
||||
)
|
||||
|
||||
|
||||
class AdoptableGenerator:
|
||||
def generate_dataset(self, num_samples: int, output_dir: str):
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for idx in range(num_samples):
|
||||
Image.new("L", (2, 2), color=255).save(out_dir / f"kept_{idx:06d}.png")
|
||||
|
||||
|
||||
class RefreshingGenerator:
|
||||
def generate_dataset(self, num_samples: int, output_dir: str):
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for idx in range(num_samples):
|
||||
Image.new("L", (2, 2), color=255).save(out_dir / f"20÷4_{idx:06d}.png")
|
||||
|
||||
|
||||
class TestDataFingerprint:
|
||||
def test_adopt_existing_dataset_when_manifest_missing(self, tmp_path):
|
||||
dataset_dir = tmp_path / "normal"
|
||||
dataset_dir.mkdir()
|
||||
for idx, label in enumerate(["AB12", "CD34"]):
|
||||
Image.new("L", (2, 2), color=255).save(dataset_dir / f"{label}_{idx:06d}.png")
|
||||
|
||||
state = ensure_synthetic_dataset(
|
||||
dataset_dir,
|
||||
generator_cls=AdoptableGenerator,
|
||||
spec=build_dataset_spec(
|
||||
AdoptableGenerator,
|
||||
config_key="normal",
|
||||
config_snapshot={"chars": "ABCD1234"},
|
||||
),
|
||||
gen_count=2,
|
||||
exact_count=2,
|
||||
adopt_if_missing=True,
|
||||
)
|
||||
|
||||
manifest = load_dataset_manifest(dataset_dir)
|
||||
assert state["adopted"] is True
|
||||
assert state["refreshed"] is False
|
||||
assert manifest is not None
|
||||
assert manifest["adopted_existing"] is True
|
||||
assert manifest["sample_count"] == 2
|
||||
|
||||
def test_refresh_dataset_when_validator_fails(self, tmp_path):
|
||||
dataset_dir = tmp_path / "math"
|
||||
dataset_dir.mkdir()
|
||||
for idx, label in enumerate(["1+1", "2-1"]):
|
||||
Image.new("L", (2, 2), color=255).save(dataset_dir / f"{label}_{idx:06d}.png")
|
||||
|
||||
state = ensure_synthetic_dataset(
|
||||
dataset_dir,
|
||||
generator_cls=RefreshingGenerator,
|
||||
spec=build_dataset_spec(
|
||||
RefreshingGenerator,
|
||||
config_key="math",
|
||||
config_snapshot={"operators": ["+", "-", "÷"]},
|
||||
),
|
||||
gen_count=2,
|
||||
exact_count=2,
|
||||
validator=lambda files: any("÷" in path.stem for path in files),
|
||||
adopt_if_missing=True,
|
||||
)
|
||||
|
||||
files = sorted(dataset_dir.glob("*.png"))
|
||||
manifest = load_dataset_manifest(dataset_dir)
|
||||
assert state["refreshed"] is True
|
||||
assert state["adopted"] is False
|
||||
assert manifest is not None
|
||||
assert manifest["adopted_existing"] is False
|
||||
assert len(files) == 2
|
||||
assert all("÷" in path.stem for path in files)
|
||||
|
||||
|
||||
class TestModelMetadata:
|
||||
def test_write_and_load_model_metadata(self, tmp_path):
|
||||
model_path = tmp_path / "normal.onnx"
|
||||
model_path.touch()
|
||||
|
||||
write_model_metadata(
|
||||
model_path,
|
||||
{
|
||||
"model_name": "normal",
|
||||
"task": "ctc",
|
||||
"chars": "ABC",
|
||||
"input_shape": [1, 40, 120],
|
||||
},
|
||||
)
|
||||
|
||||
metadata = load_model_metadata(model_path)
|
||||
assert metadata is not None
|
||||
assert metadata["version"] == 1
|
||||
assert metadata["chars"] == "ABC"
|
||||
assert metadata["task"] == "ctc"
|
||||
123
tests/test_funcaptcha.py
Normal file
123
tests/test_funcaptcha.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
|
||||
import inference.fun_captcha as fun_module
|
||||
from inference.fun_captcha import FunCaptchaRollballPipeline
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform
|
||||
|
||||
|
||||
def _build_rollball_image(path: Path, answer_idx: int = 2):
|
||||
colors = [
|
||||
(255, 80, 80),
|
||||
(80, 255, 80),
|
||||
(80, 80, 255),
|
||||
(255, 220, 80),
|
||||
]
|
||||
image = Image.new("RGB", (800, 400), color=(245, 245, 245))
|
||||
for idx, color in enumerate(colors):
|
||||
tile = Image.new("RGB", (200, 200), color=color)
|
||||
image.paste(tile, (idx * 200, 0))
|
||||
|
||||
reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
|
||||
image.paste(reference, (0, 200))
|
||||
image.save(path)
|
||||
|
||||
|
||||
class TestFunCaptchaChallengeDataset:
|
||||
def test_dataset_splits_candidates_and_reference(self, tmp_path):
|
||||
sample_path = tmp_path / "2_demo.png"
|
||||
_build_rollball_image(sample_path, answer_idx=2)
|
||||
|
||||
dataset = FunCaptchaChallengeDataset(
|
||||
dirs=[tmp_path],
|
||||
task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
|
||||
transform=build_val_rgb_transform(*IMAGE_SIZE["funcaptcha_rollball_animals"]),
|
||||
)
|
||||
|
||||
candidates, reference, answer_idx = dataset[0]
|
||||
assert candidates.shape == (4, 3, 48, 48)
|
||||
assert reference.shape == (3, 48, 48)
|
||||
assert int(answer_idx.item()) == 2
|
||||
|
||||
|
||||
class TestFunCaptchaSiamese:
|
||||
def test_forward_shape(self):
|
||||
model = FunCaptchaSiamese()
|
||||
model.eval()
|
||||
candidate = torch.randn(5, 3, 48, 48)
|
||||
reference = torch.randn(5, 3, 48, 48)
|
||||
out = model(candidate, reference)
|
||||
assert out.shape == (5, 1)
|
||||
|
||||
def test_param_count_reasonable(self):
|
||||
model = FunCaptchaSiamese()
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
assert n < 450_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class _FakeSessionOptions:
|
||||
def __init__(self):
|
||||
self.inter_op_num_threads = 0
|
||||
self.intra_op_num_threads = 0
|
||||
|
||||
|
||||
class _FakeInput:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.path = path
|
||||
|
||||
def get_inputs(self):
|
||||
return [_FakeInput("candidate"), _FakeInput("reference")]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
batch_size = next(iter(feed_dict.values())).shape[0]
|
||||
logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
|
||||
if batch_size >= 3:
|
||||
logits[2, 0] = 0.95
|
||||
return [logits]
|
||||
|
||||
|
||||
class _FakeOrt:
|
||||
SessionOptions = _FakeSessionOptions
|
||||
InferenceSession = _FakeSession
|
||||
|
||||
|
||||
class TestFunCaptchaPipeline:
|
||||
def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
|
||||
model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
|
||||
model_path.touch()
|
||||
write_model_metadata(
|
||||
model_path,
|
||||
{
|
||||
"model_name": "funcaptcha_rollball_animals",
|
||||
"task": "funcaptcha_siamese",
|
||||
"question": "4_3d_rollball_animals",
|
||||
"num_candidates": 4,
|
||||
"tile_size": [200, 200],
|
||||
"reference_box": [0, 200, 200, 400],
|
||||
"answer_index_base": 0,
|
||||
"input_shape": [3, 48, 48],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
sample_path = tmp_path / "1_demo.png"
|
||||
_build_rollball_image(sample_path, answer_idx=1)
|
||||
|
||||
pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(sample_path)
|
||||
assert result["question"] == "4_3d_rollball_animals"
|
||||
assert result["objects"] == [2]
|
||||
assert result["result"] == "2"
|
||||
assert len(result["scores"]) == 4
|
||||
@@ -9,7 +9,14 @@ import re
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from config import GENERATE_CONFIG, NORMAL_CHARS, MATH_CHARS, THREED_CHARS, SOLVER_CONFIG
|
||||
from config import (
|
||||
GENERATE_CONFIG,
|
||||
NORMAL_CHARS,
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
)
|
||||
from generators import (
|
||||
NormalCaptchaGenerator,
|
||||
MathCaptchaGenerator,
|
||||
@@ -67,6 +74,10 @@ class TestMathCaptchaGenerator:
|
||||
img, label = self.gen.generate()
|
||||
assert re.match(r"^\d+[+\-×÷]\d+$", label), f"unexpected label format: {label!r}"
|
||||
|
||||
def test_generate_with_division_text(self):
|
||||
img, label = self.gen.generate(text="20÷4")
|
||||
assert label == "20÷4"
|
||||
|
||||
|
||||
class TestThreeDCaptchaGenerator:
|
||||
def setup_method(self):
|
||||
@@ -150,7 +161,25 @@ class TestSlideDataGenerator:
|
||||
def test_label_is_numeric(self):
|
||||
img, label = self.gen.generate()
|
||||
val = int(label)
|
||||
assert val >= 0
|
||||
gs = self.gen.gap_size
|
||||
margin = gs + 10
|
||||
assert margin + gs // 2 <= val <= self.gen.width - margin + gs // 2
|
||||
|
||||
def test_labels_normalize_inside_solver_range(self, tmp_path):
|
||||
for idx in range(3):
|
||||
img, label = self.gen.generate()
|
||||
img.save(tmp_path / f"{label}_{idx:06d}.png")
|
||||
|
||||
from training.dataset import RegressionDataset
|
||||
|
||||
ds = RegressionDataset(
|
||||
dirs=[tmp_path],
|
||||
label_range=SOLVER_REGRESSION_RANGE["slide"],
|
||||
transform=None,
|
||||
)
|
||||
assert len(ds.samples) == 3
|
||||
for _, norm in ds.samples:
|
||||
assert 0.0 < norm < 1.0
|
||||
|
||||
|
||||
class TestRotateSolverDataGenerator:
|
||||
|
||||
@@ -8,11 +8,15 @@
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.model_metadata import write_model_metadata
|
||||
import inference.pipeline as pipeline_module
|
||||
from inference.pipeline import CaptchaPipeline
|
||||
|
||||
|
||||
@@ -97,6 +101,96 @@ class TestCTCGreedyDecode:
|
||||
assert result == "AA"
|
||||
|
||||
|
||||
class _FakeInput:
|
||||
name = "input"
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.model_name = Path(path).name
|
||||
|
||||
def get_inputs(self):
|
||||
return [_FakeInput()]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
if self.model_name == "classifier.onnx":
|
||||
return [np.array([[0.1, 0.9]], dtype=np.float32)]
|
||||
if self.model_name == "normal.onnx":
|
||||
logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
|
||||
logits[0, 0, 2] = 10.0
|
||||
logits[1, 0, 0] = 10.0
|
||||
return [logits]
|
||||
if self.model_name == "threed_rotate.onnx":
|
||||
return [np.array([[0.25]], dtype=np.float32)]
|
||||
raise AssertionError(f"unexpected fake session: {self.model_name}")
|
||||
|
||||
|
||||
class _FakeSessionOptions:
|
||||
def __init__(self):
|
||||
self.inter_op_num_threads = 0
|
||||
self.intra_op_num_threads = 0
|
||||
|
||||
|
||||
class _FakeOrt:
|
||||
SessionOptions = _FakeSessionOptions
|
||||
InferenceSession = _FakeSession
|
||||
|
||||
|
||||
class TestPipelineMetadata:
|
||||
def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "classifier.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "classifier.onnx",
|
||||
{
|
||||
"model_name": "classifier",
|
||||
"task": "classifier",
|
||||
"class_names": ["math", "normal"],
|
||||
"input_shape": [1, 64, 128],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
captcha_type = pipeline.classify(Image.new("RGB", (32, 32), color="white"))
|
||||
assert captcha_type == "normal"
|
||||
|
||||
def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "normal.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "normal.onnx",
|
||||
{
|
||||
"model_name": "normal",
|
||||
"task": "ctc",
|
||||
"chars": "XYZ",
|
||||
"input_shape": [1, 40, 120],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="normal")
|
||||
assert result["raw"] == "Y"
|
||||
assert result["result"] == "Y"
|
||||
|
||||
def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "threed_rotate.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "threed_rotate.onnx",
|
||||
{
|
||||
"model_name": "threed_rotate",
|
||||
"task": "regression",
|
||||
"label_range": [100, 200],
|
||||
"input_shape": [1, 80, 80],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="3d_rotate")
|
||||
assert result["raw"] == "125.0"
|
||||
assert result["result"] == "125"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SlideSolver 测试
|
||||
# ============================================================
|
||||
|
||||
423
tests/test_server.py
Normal file
423
tests/test_server.py
Normal file
@@ -0,0 +1,423 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from urllib.error import URLError
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("fastapi")
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from config import SERVER_CONFIG
|
||||
import server as server_module
|
||||
from server import create_app
|
||||
|
||||
|
||||
class _FakePipeline:
|
||||
def solve(self, image, captcha_type=None):
|
||||
return {
|
||||
"type": captcha_type or "normal",
|
||||
"result": "A3B8",
|
||||
"raw": "A3B8",
|
||||
"time_ms": 1.23,
|
||||
}
|
||||
|
||||
|
||||
class _FakeFunPipeline:
|
||||
def solve(self, image):
|
||||
return {
|
||||
"type": "funcaptcha",
|
||||
"question": "4_3d_rollball_animals",
|
||||
"objects": [2],
|
||||
"result": "2",
|
||||
"raw": "2",
|
||||
"time_ms": 2.34,
|
||||
}
|
||||
|
||||
|
||||
def _create_test_app(funcaptcha_factories=None):
|
||||
return create_app(
|
||||
pipeline_factory=_FakePipeline,
|
||||
funcaptcha_factories=funcaptcha_factories,
|
||||
)
|
||||
|
||||
|
||||
def _get_route(app, path: str):
|
||||
for route in app.routes:
|
||||
if getattr(route, "path", None) == path:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"route not found: {path}")
|
||||
|
||||
|
||||
def _fake_request(host: str = "127.0.0.1"):
|
||||
return SimpleNamespace(client=SimpleNamespace(host=host))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_server_config(monkeypatch, tmp_path):
|
||||
monkeypatch.setitem(SERVER_CONFIG, "client_key", None)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "tasks_dir", str(tmp_path / "server_tasks"))
|
||||
monkeypatch.setitem(SERVER_CONFIG, "task_cost", 0.0)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_max_retries", 2)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_delay_seconds", 1.0)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_backoff", 2.0)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_signing_secret", None)
|
||||
|
||||
|
||||
def test_solve_base64_returns_sync_payload():
|
||||
app = _create_test_app()
|
||||
solve = _get_route(app, "/api/v1/solve")
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
response = asyncio.run(
|
||||
solve(SimpleNamespace(image=encoded, type="math"))
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"type": "math",
|
||||
"result": "A3B8",
|
||||
"raw": "A3B8",
|
||||
"time_ms": 1.23,
|
||||
}
|
||||
|
||||
|
||||
def test_create_task_and_get_task_result():
|
||||
app = _create_test_app()
|
||||
create_task = _get_route(app, "/createTask")
|
||||
get_task_result = _get_route(app, "/getTaskResult")
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
create_response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="local",
|
||||
task=SimpleNamespace(
|
||||
type="ImageToTextTaskM1",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType="normal",
|
||||
),
|
||||
),
|
||||
_fake_request("10.0.0.8"),
|
||||
)
|
||||
)
|
||||
|
||||
assert create_response["errorId"] == 0
|
||||
assert create_response["status"] == "processing"
|
||||
assert isinstance(create_response["createTime"], int)
|
||||
assert isinstance(create_response["expiresAt"], int)
|
||||
task_id = create_response["taskId"]
|
||||
|
||||
result_response = None
|
||||
for _ in range(20):
|
||||
result_response = asyncio.run(
|
||||
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
|
||||
)
|
||||
if result_response.get("status") == "ready":
|
||||
break
|
||||
|
||||
assert result_response is not None
|
||||
assert result_response["errorId"] == 0
|
||||
assert result_response["status"] == "ready"
|
||||
assert result_response["solution"] == {
|
||||
"text": "A3B8",
|
||||
"answer": "A3B8",
|
||||
"raw": "A3B8",
|
||||
"captchaType": "normal",
|
||||
"timeMs": 1.23,
|
||||
}
|
||||
assert result_response["cost"] == "0.00000"
|
||||
assert result_response["ip"] == "10.0.0.8"
|
||||
assert result_response["solveCount"] == 1
|
||||
assert result_response["task"] == {
|
||||
"type": "ImageToTextTaskM1",
|
||||
"captchaType": "normal",
|
||||
}
|
||||
assert result_response["callback"] == {
|
||||
"configured": False,
|
||||
"url": None,
|
||||
"attempts": 0,
|
||||
"delivered": False,
|
||||
"deliveredAt": None,
|
||||
"lastError": None,
|
||||
}
|
||||
assert isinstance(result_response["expiresAt"], int)
|
||||
|
||||
|
||||
def test_get_task_result_returns_not_found_for_unknown_task():
|
||||
app = _create_test_app()
|
||||
get_task_result = _get_route(app, "/getTaskResult")
|
||||
|
||||
response = asyncio.run(
|
||||
get_task_result(SimpleNamespace(clientKey="local", taskId="missing-task"))
|
||||
)
|
||||
|
||||
assert response["errorCode"] == "ERROR_TASK_NOT_FOUND"
|
||||
|
||||
|
||||
def test_solve_returns_json_error_for_invalid_base64():
|
||||
app = _create_test_app()
|
||||
solve = _get_route(app, "/solve")
|
||||
|
||||
response = asyncio.run(
|
||||
solve(SimpleNamespace(image="not_base64!", type="normal"))
|
||||
)
|
||||
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_health_alias_reports_client_key_flag(monkeypatch):
|
||||
app = _create_test_app()
|
||||
health = _get_route(app, "/api/v1/health")
|
||||
|
||||
monkeypatch.setitem(SERVER_CONFIG, "client_key", "secret")
|
||||
response = health()
|
||||
|
||||
assert response["status"] == "ok"
|
||||
assert response["client_key_required"] is True
|
||||
assert "ImageToTextTask" in response["supported_task_types"]
|
||||
|
||||
|
||||
def test_client_key_is_required_for_task_api(monkeypatch):
|
||||
app = _create_test_app()
|
||||
create_task = _get_route(app, "/api/v1/createTask")
|
||||
get_balance = _get_route(app, "/api/v1/getBalance")
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
monkeypatch.setitem(SERVER_CONFIG, "client_key", "secret")
|
||||
|
||||
create_response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="wrong-key",
|
||||
task=SimpleNamespace(
|
||||
type="ImageToTextTask",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType="normal",
|
||||
),
|
||||
),
|
||||
_fake_request(),
|
||||
)
|
||||
)
|
||||
balance_response = asyncio.run(
|
||||
get_balance(SimpleNamespace(clientKey="wrong-key"))
|
||||
)
|
||||
|
||||
assert create_response["errorCode"] == "ERROR_KEY_DOES_NOT_EXIST"
|
||||
assert balance_response["errorCode"] == "ERROR_KEY_DOES_NOT_EXIST"
|
||||
|
||||
|
||||
def test_create_task_triggers_callback(monkeypatch):
|
||||
app = _create_test_app()
|
||||
create_task = _get_route(app, "/createTask")
|
||||
callbacks = []
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
def _fake_post_callback(callback_url, payload):
|
||||
callbacks.append((callback_url, payload))
|
||||
|
||||
monkeypatch.setattr(server_module, "_post_callback", _fake_post_callback)
|
||||
|
||||
response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="local",
|
||||
callbackUrl="https://example.com/callback",
|
||||
task=SimpleNamespace(
|
||||
type="ImageToTextTask",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType="normal",
|
||||
),
|
||||
),
|
||||
_fake_request("10.0.0.9"),
|
||||
)
|
||||
)
|
||||
|
||||
assert response["errorId"] == 0
|
||||
|
||||
for _ in range(20):
|
||||
if callbacks:
|
||||
break
|
||||
time.sleep(0.01)
|
||||
|
||||
assert callbacks == [
|
||||
(
|
||||
"https://example.com/callback",
|
||||
{
|
||||
"id": response["taskId"],
|
||||
"taskId": response["taskId"],
|
||||
"status": "ready",
|
||||
"errorId": "0",
|
||||
"code": "A3B8",
|
||||
"text": "A3B8",
|
||||
"answer": "A3B8",
|
||||
"raw": "A3B8",
|
||||
"captchaType": "normal",
|
||||
"timeMs": "1.23",
|
||||
"cost": "0.00000",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_create_task_routes_fun_captcha_question():
|
||||
app = _create_test_app(
|
||||
funcaptcha_factories={"4_3d_rollball_animals": _FakeFunPipeline}
|
||||
)
|
||||
create_task = _get_route(app, "/createTask")
|
||||
get_task_result = _get_route(app, "/getTaskResult")
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
create_response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="local",
|
||||
task=SimpleNamespace(
|
||||
type="FunCaptcha",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType=None,
|
||||
question="4_3d_rollball_animals",
|
||||
),
|
||||
),
|
||||
_fake_request("10.0.0.7"),
|
||||
)
|
||||
)
|
||||
|
||||
task_id = create_response["taskId"]
|
||||
result_response = None
|
||||
for _ in range(20):
|
||||
result_response = asyncio.run(
|
||||
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
|
||||
)
|
||||
if result_response.get("status") == "ready":
|
||||
break
|
||||
|
||||
assert result_response["errorId"] == 0
|
||||
assert result_response["status"] == "ready"
|
||||
assert result_response["solution"] == {
|
||||
"objects": [2],
|
||||
"answer": 2,
|
||||
"raw": "2",
|
||||
"timeMs": 2.34,
|
||||
"question": "4_3d_rollball_animals",
|
||||
"text": "2",
|
||||
}
|
||||
assert result_response["task"] == {
|
||||
"type": "FunCaptcha",
|
||||
"captchaType": None,
|
||||
"question": "4_3d_rollball_animals",
|
||||
}
|
||||
|
||||
|
||||
def test_create_task_retries_callback(monkeypatch):
|
||||
app = _create_test_app()
|
||||
create_task = _get_route(app, "/createTask")
|
||||
attempts = []
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
def _flaky_post_callback(callback_url, payload):
|
||||
attempts.append((callback_url, payload["taskId"]))
|
||||
if len(attempts) < 3:
|
||||
raise URLError("temporary failure")
|
||||
|
||||
monkeypatch.setattr(server_module, "_post_callback", _flaky_post_callback)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_delay_seconds", 0.0)
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_retry_backoff", 1.0)
|
||||
|
||||
response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="local",
|
||||
callbackUrl="https://example.com/callback",
|
||||
task=SimpleNamespace(
|
||||
type="ImageToTextTask",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType="normal",
|
||||
),
|
||||
),
|
||||
_fake_request(),
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(20):
|
||||
if len(attempts) >= 3:
|
||||
break
|
||||
time.sleep(0.01)
|
||||
|
||||
assert response["errorId"] == 0
|
||||
assert response["status"] == "processing"
|
||||
assert attempts == [
|
||||
("https://example.com/callback", response["taskId"]),
|
||||
("https://example.com/callback", response["taskId"]),
|
||||
("https://example.com/callback", response["taskId"]),
|
||||
]
|
||||
|
||||
|
||||
def test_tasks_are_restored_from_disk():
|
||||
app = _create_test_app()
|
||||
create_task = _get_route(app, "/createTask")
|
||||
get_task_result = _get_route(app, "/getTaskResult")
|
||||
encoded = base64.b64encode(b"fake-image-bytes").decode("ascii")
|
||||
|
||||
create_response = asyncio.run(
|
||||
create_task(
|
||||
SimpleNamespace(
|
||||
clientKey="local",
|
||||
task=SimpleNamespace(
|
||||
type="ImageToTextTask",
|
||||
body=encoded,
|
||||
image=None,
|
||||
captchaType="normal",
|
||||
),
|
||||
),
|
||||
_fake_request(),
|
||||
)
|
||||
)
|
||||
task_id = create_response["taskId"]
|
||||
|
||||
for _ in range(20):
|
||||
result = asyncio.run(
|
||||
get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
|
||||
)
|
||||
if result.get("status") == "ready":
|
||||
break
|
||||
time.sleep(0.01)
|
||||
|
||||
task_file = Path(SERVER_CONFIG["tasks_dir"]) / f"{task_id}.json"
|
||||
assert task_file.exists()
|
||||
|
||||
reloaded_app = _create_test_app()
|
||||
reloaded_get_task_result = _get_route(reloaded_app, "/getTaskResult")
|
||||
reloaded_result = asyncio.run(
|
||||
reloaded_get_task_result(SimpleNamespace(clientKey="local", taskId=task_id))
|
||||
)
|
||||
|
||||
assert reloaded_result["status"] == "ready"
|
||||
assert reloaded_result["solution"]["answer"] == "A3B8"
|
||||
|
||||
|
||||
def test_callback_request_includes_signature_headers(monkeypatch):
|
||||
monkeypatch.setitem(SERVER_CONFIG, "callback_signing_secret", "shared-secret")
|
||||
|
||||
request = server_module._build_callback_request(
|
||||
"https://example.com/callback",
|
||||
{"taskId": "abc", "status": "ready"},
|
||||
)
|
||||
|
||||
body = urlencode({"taskId": "abc", "status": "ready"}).encode("utf-8")
|
||||
headers = {key.lower(): value for key, value in request.header_items()}
|
||||
timestamp = headers["x-captchabreaker-timestamp"]
|
||||
signature = headers["x-captchabreaker-signature"]
|
||||
|
||||
assert headers["content-type"] == "application/x-www-form-urlencoded"
|
||||
assert headers["x-captchabreaker-signature-alg"] == "hmac-sha256"
|
||||
assert signature == server_module._sign_callback_payload(body, timestamp, "shared-secret")
|
||||
Reference in New Issue
Block a user