Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components - Implement FastAPI HTTP service (server.py) with POST /solve and GET /health - Add checkpoint resume (断点续训) to both CTC and regression training utils - Fix device mismatch bug in CTC training (targets/input_lengths on GPU) - Add pytest dev dependency to pyproject.toml - Update .gitignore with data/solver/, data/real/, *.log - Remove PyCharm template main.py - Update training/__init__.py docs for solver training scripts Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -175,10 +175,26 @@ def train_regression_model(
|
||||
|
||||
best_mae = float("inf")
|
||||
best_tol_acc = 0.0
|
||||
start_epoch = 1
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
|
||||
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
Reference in New Issue
Block a user