New solver subsystem with independent models: - GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection - RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction - SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback - RotateSolver with ONNX sin/cos → atan2 inference - Generators, training scripts, CLI commands, and slide track utility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
66 lines
1.7 KiB
Python
66 lines
1.7 KiB
Python
"""
训练滑块缺口检测 CNN (GapDetectorCNN)

复用 train_regression_utils 的通用回归训练流程。

用法: python -m training.train_slide
"""
|
||
|
||
from config import (
|
||
SOLVER_CONFIG,
|
||
SOLVER_TRAIN_CONFIG,
|
||
SOLVER_REGRESSION_RANGE,
|
||
SLIDE_DATA_DIR,
|
||
CHECKPOINTS_DIR,
|
||
ONNX_DIR,
|
||
ONNX_CONFIG,
|
||
RANDOM_SEED,
|
||
get_device,
|
||
)
|
||
from generators.slide_gen import SlideDataGenerator
|
||
from models.gap_detector import GapDetectorCNN
|
||
|
||
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
|
||
# 以便复用 train_regression_utils
|
||
import config as _cfg
|
||
|
||
|
||
def main():
    """Train the slider-gap detector CNN (GapDetectorCNN).

    Builds a ``GapDetectorCNN`` whose input size comes from
    ``SOLVER_CONFIG["slide"]["cnn_input_size"]`` and hands it to the shared
    ``train_regression_model`` pipeline under the ``"slide_cnn"`` config key,
    using ``SlideDataGenerator`` for synthetic data.

    The shared pipeline is reused by injecting the solver-specific settings
    (training config, image size, regression range) into the global
    ``config.TRAIN_CONFIG`` / ``config.IMAGE_SIZE`` / ``config.REGRESSION_RANGE``
    tables under ``"slide_cnn"``. The injection is temporary: whatever those
    tables held for that key beforehand is restored once training finishes,
    including when training raises, so the global config is not left
    modified for the rest of the process.

    Data layout: synthetic samples live in ``SLIDE_DATA_DIR`` and real samples
    in ``SLIDE_DATA_DIR / "real"``; both directories are created if missing.
    """
    config_key = "slide_cnn"
    solver_cfg = SOLVER_CONFIG["slide"]
    train_cfg = SOLVER_TRAIN_CONFIG[config_key]
    img_h, img_w = solver_cfg["cnn_input_size"]

    model = GapDetectorCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("训练滑块缺口检测 CNN (GapDetectorCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(" 任务: 预测缺口 x 坐标百分比")
    print("=" * 60)

    # Settings the shared regression pipeline looks up under ``config_key``.
    # All values are computed before any table is touched, so a missing
    # solver setting fails without leaving a half-applied injection behind.
    injections = [
        (_cfg.TRAIN_CONFIG, train_cfg),
        (_cfg.IMAGE_SIZE, (img_h, img_w)),
        (_cfg.REGRESSION_RANGE, SOLVER_REGRESSION_RANGE["slide"]),
    ]
    missing = object()  # sentinel: the key was absent before injection
    previous = [(table, table.get(config_key, missing)) for table, _ in injections]
    for table, value in injections:
        table[config_key] = value

    try:
        # Imported only after the injection, matching the original flow.
        from training.train_regression_utils import train_regression_model

        # parents=True also creates SLIDE_DATA_DIR itself when it is missing.
        real_dir = SLIDE_DATA_DIR / "real"
        real_dir.mkdir(parents=True, exist_ok=True)

        train_regression_model(
            model_name="gap_detector",
            model=model,
            synthetic_dir=str(SLIDE_DATA_DIR),
            real_dir=str(real_dir),
            generator_cls=SlideDataGenerator,
            config_key=config_key,
        )
    finally:
        # Undo the temporary injection so later users of ``config`` see the
        # tables exactly as they were before this call.
        for table, old_value in previous:
            if old_value is missing:
                table.pop(config_key, None)
            else:
                table[config_key] = old_value
|
||
|
||
|
||
# Script entry point; the module docstring documents running this as
# `python -m training.train_slide`.
if __name__ == "__main__":
    main()
|