Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,4 +10,6 @@
|
||||
- train_3d_rotate.py: 训练 3D 旋转回归 (RegressionCNN)
|
||||
- train_3d_slider.py: 训练 3D 滑块回归 (RegressionCNN)
|
||||
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
|
||||
- train_slide.py: 训练滑块缺口检测 (GapDetectorCNN)
|
||||
- train_rotate_solver.py: 训练旋转角度回归 (RotationRegressor)
|
||||
"""
|
||||
|
||||
@@ -175,10 +175,26 @@ def train_regression_model(
|
||||
|
||||
best_mae = float("inf")
|
||||
best_tol_acc = 0.0
|
||||
start_epoch = 1
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
|
||||
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
@@ -170,10 +170,25 @@ def train_ctc_model(
|
||||
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
|
||||
|
||||
best_acc = 0.0
|
||||
start_epoch = 1
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
|
||||
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
@@ -181,11 +196,12 @@ def train_ctc_model(
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets, target_lengths, _ in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device)
|
||||
target_lengths = target_lengths.to(device)
|
||||
|
||||
logits = model(images) # (T, B, C)
|
||||
T, B, C = logits.shape
|
||||
# cuDNN CTC requires targets/lengths on CPU
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32)
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
|
||||
|
||||
log_probs = logits.log_softmax(2)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
|
||||
Reference in New Issue
Block a user