Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components - Implement FastAPI HTTP service (server.py) with POST /solve and GET /health - Add checkpoint resume (断点续训) to both CTC and regression training utils - Fix device mismatch bug in CTC training (targets/input_lengths on GPU) - Add pytest dev dependency to pyproject.toml - Update .gitignore with data/solver/, data/real/, *.log - Remove PyCharm template main.py - Update training/__init__.py docs for solver training scripts Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -170,10 +170,25 @@ def train_ctc_model(
|
||||
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
|
||||
|
||||
best_acc = 0.0
|
||||
start_epoch = 1
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
|
||||
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
@@ -181,11 +196,12 @@ def train_ctc_model(
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets, target_lengths, _ in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device)
|
||||
target_lengths = target_lengths.to(device)
|
||||
|
||||
logits = model(images) # (T, B, C)
|
||||
T, B, C = logits.shape
|
||||
# cuDNN CTC requires targets/lengths on CPU
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32)
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
|
||||
|
||||
log_probs = logits.log_softmax(2)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
|
||||
Reference in New Issue
Block a user