446 lines
15 KiB
Python
446 lines
15 KiB
Python
"""
|
||
CaptchaBreaker 命令行入口
|
||
|
||
用法:
|
||
python cli.py generate --type normal --num 60000
|
||
python cli.py generate --type 3d_text --num 80000
|
||
python cli.py generate --type 3d_rotate --num 60000
|
||
python cli.py generate --type 3d_slider --num 60000
|
||
python cli.py train --model normal
|
||
python cli.py train --all
|
||
python cli.py export --all
|
||
python cli.py predict image.png
|
||
python cli.py predict image.png --type normal
|
||
python cli.py predict-dir ./test_images/
|
||
python cli.py serve --port 8080
|
||
python cli.py generate-solver slide --num 30000
|
||
python cli.py train-solver slide
|
||
python cli.py train-solver rotate
|
||
python cli.py solve slide --bg bg.png [--tpl tpl.png]
|
||
python cli.py solve rotate --image img.png
|
||
python cli.py train-funcaptcha --question 4_3d_rollball_animals
|
||
python cli.py predict-funcaptcha image.jpg --question 4_3d_rollball_animals
|
||
"""
|
||
|
||
import argparse
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
|
||
def cmd_generate(args):
    """Generate synthetic training data for one captcha type or for the classifier.

    Args:
        args: Parsed ``generate`` namespace. Reads ``args.type`` (a key of the
            generator table below, or ``"classifier"``) and ``args.num`` (the
            total number of images to generate).

    Exits with status 1 when ``args.num`` is not positive, when the classifier
    split would give zero images per class, or when ``args.type`` is unknown.
    """
    captcha_type = args.type
    num = args.num

    # Validate before importing the generator stack so bad input fails fast.
    if num <= 0:
        print(f"生成数量必须为正整数: {num}")
        sys.exit(1)

    from config import (
        SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
        SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
        CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
    )
    from generators import (
        NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator,
        ThreeDRotateGenerator, ThreeDSliderGenerator,
    )

    # Captcha type -> (generator class, default output directory).
    gen_map = {
        "normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
        "math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
        "3d_text": (ThreeDCaptchaGenerator, SYNTHETIC_3D_TEXT_DIR),
        "3d_rotate": (ThreeDRotateGenerator, SYNTHETIC_3D_ROTATE_DIR),
        "3d_slider": (ThreeDSliderGenerator, SYNTHETIC_3D_SLIDER_DIR),
    }

    if captcha_type == "classifier":
        # Classifier data: split ``num`` evenly across the captcha types; the
        # remainder of the integer division is not generated.
        per_class = num // NUM_CAPTCHA_TYPES
        if per_class == 0:
            print(f"生成数量 {num} 小于类型数 {NUM_CAPTCHA_TYPES}, 无法为每类生成数据")
            sys.exit(1)
        print(f"生成分类器训练数据: 每类 {per_class} 张")
        for cls_name in CAPTCHA_TYPES:
            # Classifier samples go under CLASSIFIER_DIR/<type>, not the
            # per-type synthetic directory, so that path is ignored here.
            gen_cls, _ = gen_map[cls_name]
            cls_dir = CLASSIFIER_DIR / cls_name
            cls_dir.mkdir(parents=True, exist_ok=True)
            gen_cls().generate_dataset(per_class, str(cls_dir))
    elif captcha_type in gen_map:
        gen_cls, out_dir = gen_map[captcha_type]
        # Create the output directory up front, as the classifier branch does;
        # Path() accepts either a str or a Path config value.
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        print(f"生成 {captcha_type} 数据: {num} 张 → {out_dir}")
        gen_cls().generate_dataset(num, str(out_dir))
    else:
        valid = ", ".join(list(gen_map.keys()) + ["classifier"])
        print(f"未知类型: {captcha_type} 可选: {valid}")
        sys.exit(1)
|
||
|
||
|
||
def cmd_train(args):
    """Train one model, or every model in dependency order.

    Args:
        args: Parsed ``train`` namespace. ``args.all`` trains every model;
            otherwise ``args.model`` names a single model.

    Exits with status 1 when neither ``--all`` nor ``--model`` is given, or
    when ``--model`` names an unknown model.
    """
    import importlib

    # Model name -> training module exposing ``main()``. Insertion order is the
    # order ``--all`` trains in (the classifier comes last).
    train_modules = {
        "normal": "training.train_normal",
        "math": "training.train_math",
        "3d_text": "training.train_3d_text",
        "3d_rotate": "training.train_3d_rotate",
        "3d_slider": "training.train_3d_slider",
        "classifier": "training.train_classifier",
    }

    def load_train_fn(name):
        # Lazily import the module so only the requested trainers are loaded.
        return importlib.import_module(train_modules[name]).main

    if args.all:
        print("按顺序训练全部模型: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier\n")
        # Import every entry point before any training starts, so a broken
        # module is reported immediately rather than partway through the run.
        train_fns = [load_train_fn(name) for name in train_modules]
        for index, train_fn in enumerate(train_fns):
            if index:
                print("\n")
            train_fn()
        return

    model = args.model
    if not model:
        print("请指定 --all 或 --model <name>")
        sys.exit(1)
    if model not in train_modules:
        print(f"未知模型: {model} 可选: {', '.join(train_modules)}")
        sys.exit(1)

    load_train_fn(model)()
|
||
|
||
|
||
def cmd_export(args):
    """Export trained models to ONNX.

    Args:
        args: Parsed ``export`` namespace. ``args.all`` exports every model via
            ``export_all``; otherwise ``args.model`` names one model, and the
            CLI spellings (``3d_*``, the FunCaptcha question name) are mapped
            to the names ``_load_and_export`` is called with.

    Exits with status 1 when neither ``--all`` nor ``--model`` is given.
    """
    if not args.all and not args.model:
        print("请指定 --all 或 --model <name>")
        sys.exit(1)

    # Imported only after argument validation, so a usage error does not
    # require the export/inference stack to be importable.
    from inference.export_onnx import export_all, _load_and_export

    if args.all:
        export_all()
        return

    # CLI model name -> name passed to _load_and_export; unlisted names pass through.
    alias = {
        "3d_text": "threed_text",
        "3d_rotate": "threed_rotate",
        "3d_slider": "threed_slider",
        "4_3d_rollball_animals": "funcaptcha_rollball_animals",
    }
    _load_and_export(alias.get(args.model, args.model))
|
||
|
||
|
||
def cmd_predict(args):
    """Recognize a single captcha image and print the result.

    Args:
        args: Parsed ``predict`` namespace. Reads ``args.image`` (image path)
            and ``args.type`` (``None`` by default; per the ``--type`` help, a
            given type skips classification).

    Exits with status 1 when ``args.image`` is not an existing file.
    """
    image_path = args.image
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and only fail later inside the pipeline.
    if not Path(image_path).is_file():
        print(f"文件不存在: {image_path}")
        sys.exit(1)

    # Deferred import: only load the inference stack once the input is valid.
    from inference.pipeline import CaptchaPipeline

    pipeline = CaptchaPipeline()
    result = pipeline.solve(image_path, captcha_type=args.type)

    print(f"文件: {image_path}")
    print(f"类型: {result['type']}")
    print(f"识别: {result['raw']}")
    print(f"结果: {result['result']}")
    print(f"耗时: {result['time_ms']:.1f} ms")
|
||
|
||
|
||
def cmd_predict_dir(args):
    """Recognize every image in a directory and print a result table.

    Args:
        args: Parsed ``predict-dir`` namespace. Reads ``args.directory`` and
            ``args.type`` (forwarded to the pipeline as ``captcha_type``).

    Exits with status 1 when the directory does not exist or contains no
    ``.png`` / ``.jpg`` / ``.jpeg`` files (extension match is case-insensitive).
    """
    dir_path = Path(args.directory)
    if not dir_path.is_dir():
        print(f"目录不存在: {dir_path}")
        sys.exit(1)

    # Case-insensitive extension match so IMG.PNG / photo.JPEG are included;
    # one sorted list so output order follows file names, not file type.
    image_suffixes = {".png", ".jpg", ".jpeg"}
    images = sorted(
        path for path in dir_path.iterdir()
        if path.is_file() and path.suffix.lower() in image_suffixes
    )
    if not images:
        print(f"目录中未找到图片: {dir_path}")
        sys.exit(1)

    # Build the pipeline only once there is work to do.
    from inference.pipeline import CaptchaPipeline
    pipeline = CaptchaPipeline()

    print(f"批量识别: {len(images)} 张图片\n")
    print(f"{'文件名':<30} {'类型':<10} {'结果':<15} {'耗时(ms)':>8}")
    print("-" * 67)

    total_ms = 0.0
    for img_path in images:
        result = pipeline.solve(str(img_path), captcha_type=args.type)
        total_ms += result["time_ms"]
        # str() first: alignment specs such as :<15 raise TypeError for None
        # and other non-str values.
        print(
            f"{img_path.name:<30} {str(result['type']):<10} "
            f"{str(result['result']):<15} {result['time_ms']:>8.1f}"
        )

    print("-" * 67)
    print(f"总计: {len(images)} 张 平均: {total_ms / len(images):.1f} ms 总耗时: {total_ms:.1f} ms")
|
||
|
||
|
||
def cmd_serve(args):
    """Start the HTTP recognition service with uvicorn.

    Args:
        args: Parsed ``serve`` namespace; reads ``args.host`` and ``args.port``.

    Exits with status 1 when ``server.create_app`` or ``uvicorn`` cannot be
    imported.
    """
    try:
        # Both imports belong inside the guard: the hint below names uvicorn
        # as a requirement, so a missing uvicorn must be reported the same way.
        from server import create_app
        import uvicorn
    except ImportError as exc:
        # server.py not implemented yet, or optional dependencies missing.
        print("HTTP 服务需要 FastAPI 和 uvicorn。")
        print("安装: uv sync --extra server")
        print("并确保 server.py 已实现。")
        print(f"({exc})")
        sys.exit(1)

    app = create_app()
    uvicorn.run(app, host=args.host, port=args.port)
|
||
|
||
|
||
def cmd_generate_solver(args):
    """Generate training data for an interactive-captcha solver.

    Args:
        args: Parsed ``generate-solver`` namespace. Reads ``args.type``
            (``"slide"`` or ``"rotate"``) and ``args.num`` (number of samples).

    Exits with status 1 on an unknown solver type or a non-positive ``num``.
    """
    solver_type = args.type
    num = args.num

    # Validate before importing the generator modules so bad input fails fast.
    solver_types = ("slide", "rotate")
    if solver_type not in solver_types:
        print(f"未知 solver 类型: {solver_type} 可选: {', '.join(solver_types)}")
        sys.exit(1)
    if num <= 0:
        print(f"生成数量必须为正整数: {num}")
        sys.exit(1)

    from config import SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR

    # Import only the generator that is actually needed.
    if solver_type == "slide":
        from generators.slide_gen import SlideDataGenerator as gen_cls
        out_dir = SLIDE_DATA_DIR
    else:
        from generators.rotate_solver_gen import RotateSolverDataGenerator as gen_cls
        out_dir = ROTATE_SOLVER_DATA_DIR

    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"生成 solver/{solver_type} 数据: {num} 张 → {out_dir}")
    gen = gen_cls()
    gen.generate_dataset(num, str(out_dir))
|
||
|
||
|
||
def cmd_train_solver(args):
    """Train the model behind an interactive-captcha solver.

    Args:
        args: Parsed ``train-solver`` namespace; ``args.type`` must be
            ``"slide"`` or ``"rotate"``.

    Exits with status 1 for any other solver type.
    """
    kind = args.type
    if kind not in ("slide", "rotate"):
        print(f"未知 solver 类型: {kind} 可选: slide, rotate")
        sys.exit(1)

    # Import lazily so only the selected trainer's module is loaded.
    if kind == "slide":
        from training.train_slide import main as run_training
    else:
        from training.train_rotate_solver import main as run_training

    run_training()
|
||
|
||
|
||
def cmd_solve(args):
    """Solve an interactive captcha (slider gap or rotation angle).

    Args:
        args: Parsed ``solve`` namespace. ``args.type`` selects the solver:
            ``"slide"`` reads ``args.bg`` (required) and ``args.tpl``
            (optional); ``"rotate"`` reads ``args.image`` (required).

    Exits with status 1 on an unknown solver type, a missing required image
    option, or an image path that is not an existing file.
    """
    solver_type = args.type

    def require_file(path, option):
        # --bg / --image are declared optional in argparse, so a missing value
        # arrives here as None and must be reported instead of crashing Path().
        if path is None:
            print(f"{solver_type} 需要 {option} 参数")
            sys.exit(1)
        if not Path(path).is_file():
            print(f"文件不存在: {path}")
            sys.exit(1)

    if solver_type == "slide":
        bg_path = args.bg
        tpl_path = getattr(args, "tpl", None)
        require_file(bg_path, "--bg")
        if tpl_path:
            require_file(tpl_path, "--tpl")

        from solvers.slide_solver import SlideSolver

        solver = SlideSolver()
        result = solver.solve(bg_path, template_image=tpl_path)

        print(f"背景图: {bg_path}")
        if tpl_path:
            print(f"模板图: {tpl_path}")
        print(f"缺口 x: {result['gap_x']} px")
        print(f"缺口 x%: {result['gap_x_percent']:.4f}")
        print(f"置信度: {result['confidence']:.4f}")
        print(f"方法: {result['method']}")

    elif solver_type == "rotate":
        image_path = args.image
        require_file(image_path, "--image")

        from solvers.rotate_solver import RotateSolver

        solver = RotateSolver()
        result = solver.solve(image_path)

        print(f"图片: {image_path}")
        print(f"角度: {result['angle']}°")
        print(f"置信度: {result['confidence']}")

    else:
        print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
        sys.exit(1)
|
||
|
||
|
||
def cmd_train_funcaptcha(args):
    """Train the dedicated model for one FunCaptcha question.

    Args:
        args: Parsed ``train-funcaptcha`` namespace; reads ``args.question``.

    Exits with status 1 when the question is not listed in
    ``config.FUN_CAPTCHA_TASKS`` or has no training entry point yet.
    """
    from config import FUN_CAPTCHA_TASKS
    from training.train_funcaptcha_rollball import main as train_rollball

    # Question name -> training entry point; only rollball is wired up so far.
    trainers = {"4_3d_rollball_animals": train_rollball}

    question = args.question
    if question not in FUN_CAPTCHA_TASKS:
        print(f"未知 FunCaptcha question: {question} 可选: {', '.join(FUN_CAPTCHA_TASKS)}")
        sys.exit(1)

    trainer = trainers.get(question)
    if trainer is None:
        print(f"暂未实现该 FunCaptcha 训练入口: {question}")
        sys.exit(1)

    trainer(question=question)
|
||
|
||
|
||
def cmd_predict_funcaptcha(args):
    """Run the dedicated FunCaptcha pipeline on one challenge image.

    Args:
        args: Parsed ``predict-funcaptcha`` namespace; reads ``args.image`` and
            ``args.question``.

    Exits with status 1 when the image is not an existing file, the question
    is not listed in ``config.FUN_CAPTCHA_TASKS``, or the question has no
    prediction entry point yet.
    """
    image_path = args.image
    question = args.question

    # is_file() rather than exists(): reject directories up front. Checked
    # before any project import so a typo fails fast.
    if not Path(image_path).is_file():
        print(f"文件不存在: {image_path}")
        sys.exit(1)

    from config import FUN_CAPTCHA_TASKS

    if question not in FUN_CAPTCHA_TASKS:
        print(f"未知 FunCaptcha question: {question} 可选: {', '.join(FUN_CAPTCHA_TASKS)}")
        sys.exit(1)

    if question != "4_3d_rollball_animals":
        print(f"暂未实现该 FunCaptcha 预测入口: {question}")
        sys.exit(1)

    # Imported only once the question is known to be supported.
    from inference.fun_captcha import FunCaptchaRollballPipeline

    pipeline = FunCaptchaRollballPipeline(question=question)
    result = pipeline.solve(image_path)

    print(f"文件: {image_path}")
    print(f"question: {result['question']}")
    print(f"objects: {result['objects']}")
    print(f"result: {result['result']}")
    print(f"耗时: {result['time_ms']:.1f} ms")
|
||
|
||
|
||
def _build_parser():
    """Build the ``captcha-breaker`` argument parser with every sub-command.

    Sub-commands are registered in the order they are listed by ``--help``.
    """
    # Help strings shared by several sub-commands.
    skip_classify_help = "指定类型跳过分类: normal, math, 3d_text, 3d_rotate, 3d_slider"
    model_help = "模型名: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier"
    solver_help = "solver 类型: slide, rotate"
    question_help = "专项 question,如: 4_3d_rollball_animals"

    parser = argparse.ArgumentParser(
        prog="captcha-breaker",
        description="验证码识别多模型系统 - 调度模型 + 多专家模型",
    )
    commands = parser.add_subparsers(dest="command", help="子命令")

    # generate: synthetic data for one captcha type, or the classifier.
    gen = commands.add_parser("generate", help="生成训练数据")
    gen.add_argument(
        "--type",
        required=True,
        help="验证码类型: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
    )
    gen.add_argument("--num", type=int, required=True, help="生成数量")

    # train / export take the same --model / --all pair.
    for name, cmd_help, all_help in (
        ("train", "训练模型", "按依赖顺序训练全部模型"),
        ("export", "导出 ONNX 模型", "导出全部模型"),
    ):
        sub = commands.add_parser(name, help=cmd_help)
        sub.add_argument("--model", help=model_help)
        sub.add_argument("--all", action="store_true", help=all_help)

    # predict / predict-dir: one positional path plus an optional --type.
    for name, cmd_help, positional, positional_help in (
        ("predict", "识别单张验证码", "image", "图片路径"),
        ("predict-dir", "批量识别目录中的验证码", "directory", "图片目录路径"),
    ):
        sub = commands.add_parser(name, help=cmd_help)
        sub.add_argument(positional, help=positional_help)
        sub.add_argument("--type", default=None, help=skip_classify_help)

    serve = commands.add_parser("serve", help="启动 HTTP 识别服务")
    serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
    serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")

    gen_solver = commands.add_parser("generate-solver", help="生成 solver 训练数据")
    gen_solver.add_argument("type", help=solver_help)
    gen_solver.add_argument("--num", type=int, required=True, help="生成数量")

    train_solver = commands.add_parser("train-solver", help="训练 solver 模型")
    train_solver.add_argument("type", help=solver_help)

    solve = commands.add_parser("solve", help="求解交互式验证码")
    solve.add_argument("type", help=solver_help)
    solve.add_argument("--bg", help="背景图路径 (slide 必需)")
    solve.add_argument("--tpl", default=None, help="模板图路径 (slide 可选)")
    solve.add_argument("--image", help="图片路径 (rotate 必需)")

    # FunCaptcha commands: both require --question; prediction also takes an image.
    for name, cmd_help, takes_image in (
        ("train-funcaptcha", "训练 FunCaptcha 专项模型", False),
        ("predict-funcaptcha", "识别单张 FunCaptcha challenge", True),
    ):
        sub = commands.add_parser(name, help=cmd_help)
        if takes_image:
            sub.add_argument("image", help="图片路径")
        sub.add_argument("--question", required=True, help=question_help)

    return parser


def main():
    """Parse the command line and dispatch to the matching ``cmd_*`` handler.

    With no sub-command, prints the top-level help and exits with status 0.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(0)

    # Sub-command name -> handler taking the parsed namespace.
    handlers = {
        "generate": cmd_generate,
        "train": cmd_train,
        "export": cmd_export,
        "predict": cmd_predict,
        "predict-dir": cmd_predict_dir,
        "serve": cmd_serve,
        "generate-solver": cmd_generate_solver,
        "train-solver": cmd_train_solver,
        "solve": cmd_solve,
        "train-funcaptcha": cmd_train_funcaptcha,
        "predict-funcaptcha": cmd_predict_funcaptcha,
    }
    handler = handlers[args.command]
    handler(args)


if __name__ == "__main__":
    main()
|