Add tests, server, resume training, and project cleanup

- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 19:05:47 +08:00
parent 9b5f29083e
commit 788ddcae1a
11 changed files with 786 additions and 21 deletions

View File

@@ -170,10 +170,25 @@ def train_ctc_model(
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
best_acc = 0.0
start_epoch = 1
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
# ---- 3.5 断点续训 ----
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
best_acc = ckpt.get("best_acc", 0.0)
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_acc={best_acc:.4f}"
)
# ---- 4. 训练循环 ----
for epoch in range(1, cfg["epochs"] + 1):
for epoch in range(start_epoch, cfg["epochs"] + 1):
model.train()
total_loss = 0.0
num_batches = 0
@@ -181,11 +196,12 @@ def train_ctc_model(
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, targets, target_lengths, _ in pbar:
images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
logits = model(images) # (T, B, C)
T, B, C = logits.shape
# cuDNN CTC requires targets/lengths on CPU
input_lengths = torch.full((B,), T, dtype=torch.int32)
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
log_probs = logits.log_softmax(2)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)