Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

228
cli.py Normal file
View File

@@ -0,0 +1,228 @@
"""
CaptchaBreaker 命令行入口
用法:
python cli.py generate --type normal --num 60000
python cli.py train --model normal
python cli.py train --all
python cli.py export --all
python cli.py predict image.png
python cli.py predict image.png --type normal
python cli.py predict-dir ./test_images/
python cli.py serve --port 8080
"""
import argparse
import sys
from pathlib import Path
def cmd_generate(args):
"""生成训练数据。"""
from config import (
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
)
from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator
gen_map = {
"normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
"math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
"3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR),
}
captcha_type = args.type
num = args.num
if captcha_type == "classifier":
# 分类器数据: 各类型各生成 num // num_types
per_class = num // NUM_CAPTCHA_TYPES
print(f"生成分类器训练数据: 每类 {per_class}")
for cls_name in CAPTCHA_TYPES:
gen_cls, out_dir = gen_map[cls_name]
cls_dir = CLASSIFIER_DIR / cls_name
cls_dir.mkdir(parents=True, exist_ok=True)
gen = gen_cls()
gen.generate_dataset(per_class, str(cls_dir))
elif captcha_type in gen_map:
gen_cls, out_dir = gen_map[captcha_type]
print(f"生成 {captcha_type} 数据: {num} 张 → {out_dir}")
gen = gen_cls()
gen.generate_dataset(num, str(out_dir))
else:
print(f"未知类型: {captcha_type} 可选: normal, math, 3d, classifier")
sys.exit(1)
def cmd_train(args):
"""训练模型。"""
if args.all:
# 按依赖顺序: normal → math → 3d → classifier
print("按顺序训练全部模型: normal → math → 3d → classifier\n")
from training.train_normal import main as train_normal
from training.train_math import main as train_math
from training.train_3d import main as train_3d
from training.train_classifier import main as train_classifier
train_normal()
print("\n")
train_math()
print("\n")
train_3d()
print("\n")
train_classifier()
return
model = args.model
if model == "normal":
from training.train_normal import main as train_fn
elif model == "math":
from training.train_math import main as train_fn
elif model == "3d":
from training.train_3d import main as train_fn
elif model == "classifier":
from training.train_classifier import main as train_fn
else:
print(f"未知模型: {model} 可选: normal, math, 3d, classifier")
sys.exit(1)
train_fn()
def cmd_export(args):
"""导出 ONNX 模型。"""
from inference.export_onnx import export_all, _load_and_export
if args.all:
export_all()
elif args.model:
_load_and_export(args.model)
else:
print("请指定 --all 或 --model <name>")
sys.exit(1)
def cmd_predict(args):
"""单张图片推理。"""
from inference.pipeline import CaptchaPipeline
image_path = args.image
if not Path(image_path).exists():
print(f"文件不存在: {image_path}")
sys.exit(1)
pipeline = CaptchaPipeline()
result = pipeline.solve(image_path, captcha_type=args.type)
print(f"文件: {image_path}")
print(f"类型: {result['type']}")
print(f"识别: {result['raw']}")
print(f"结果: {result['result']}")
print(f"耗时: {result['time_ms']:.1f} ms")
def cmd_predict_dir(args):
"""批量目录推理。"""
from inference.pipeline import CaptchaPipeline
dir_path = Path(args.directory)
if not dir_path.is_dir():
print(f"目录不存在: {dir_path}")
sys.exit(1)
pipeline = CaptchaPipeline()
images = sorted(dir_path.glob("*.png")) + sorted(dir_path.glob("*.jpg"))
if not images:
print(f"目录中未找到图片: {dir_path}")
sys.exit(1)
print(f"批量识别: {len(images)} 张图片\n")
print(f"{'文件名':<30} {'类型':<8} {'结果':<15} {'耗时(ms)':>8}")
print("-" * 65)
total_ms = 0.0
for img_path in images:
result = pipeline.solve(str(img_path), captcha_type=args.type)
total_ms += result["time_ms"]
print(
f"{img_path.name:<30} {result['type']:<8} "
f"{result['result']:<15} {result['time_ms']:>8.1f}"
)
print("-" * 65)
print(f"总计: {len(images)} 张 平均: {total_ms / len(images):.1f} ms 总耗时: {total_ms:.1f} ms")
def cmd_serve(args):
"""启动 HTTP 服务。"""
try:
from server import create_app
except ImportError:
# server.py 尚未实现或缺少依赖
print("HTTP 服务需要 FastAPI 和 uvicorn。")
print("安装: uv sync --extra server")
print("并确保 server.py 已实现。")
sys.exit(1)
import uvicorn
app = create_app()
uvicorn.run(app, host=args.host, port=args.port)
def main():
parser = argparse.ArgumentParser(
prog="captcha-breaker",
description="验证码识别多模型系统 - 调度模型 + 多专家模型",
)
subparsers = parser.add_subparsers(dest="command", help="子命令")
# ---- generate ----
p_gen = subparsers.add_parser("generate", help="生成训练数据")
p_gen.add_argument("--type", required=True, help="验证码类型: normal, math, 3d, classifier")
p_gen.add_argument("--num", type=int, required=True, help="生成数量")
# ---- train ----
p_train = subparsers.add_parser("train", help="训练模型")
p_train.add_argument("--model", help="模型名: normal, math, 3d, classifier")
p_train.add_argument("--all", action="store_true", help="按依赖顺序训练全部模型")
# ---- export ----
p_export = subparsers.add_parser("export", help="导出 ONNX 模型")
p_export.add_argument("--model", help="模型名: normal, math, 3d, classifier, threed")
p_export.add_argument("--all", action="store_true", help="导出全部模型")
# ---- predict ----
p_pred = subparsers.add_parser("predict", help="识别单张验证码")
p_pred.add_argument("image", help="图片路径")
p_pred.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d")
# ---- predict-dir ----
p_pdir = subparsers.add_parser("predict-dir", help="批量识别目录中的验证码")
p_pdir.add_argument("directory", help="图片目录路径")
p_pdir.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d")
# ---- serve ----
p_serve = subparsers.add_parser("serve", help="启动 HTTP 识别服务")
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
args = parser.parse_args()
if args.command is None:
parser.print_help()
sys.exit(0)
cmd_map = {
"generate": cmd_generate,
"train": cmd_train,
"export": cmd_export,
"predict": cmd_predict,
"predict-dir": cmd_predict_dir,
"serve": cmd_serve,
}
cmd_map[args.command](args)
if __name__ == "__main__":
main()