Initialize repository
This commit is contained in:
40
training/train_normal.py
Normal file
40
training/train_normal.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
训练普通字符识别模型 (LiteCRNN - normal 模式)
|
||||
|
||||
用法: python -m training.train_normal
|
||||
"""
|
||||
|
||||
from config import (
|
||||
NORMAL_CHARS,
|
||||
IMAGE_SIZE,
|
||||
SYNTHETIC_NORMAL_DIR,
|
||||
REAL_NORMAL_DIR,
|
||||
)
|
||||
from generators.normal_gen import NormalCaptchaGenerator
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from training.train_utils import train_ctc_model
|
||||
|
||||
|
||||
def main():
|
||||
img_h, img_w = IMAGE_SIZE["normal"]
|
||||
model = LiteCRNN(chars=NORMAL_CHARS, img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练普通字符识别模型 (LiteCRNN - normal)")
|
||||
print(f" 字符集: {NORMAL_CHARS} ({len(NORMAL_CHARS)} 字符)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print("=" * 60)
|
||||
|
||||
train_ctc_model(
|
||||
model_name="normal",
|
||||
model=model,
|
||||
chars=NORMAL_CHARS,
|
||||
synthetic_dir=SYNTHETIC_NORMAL_DIR,
|
||||
real_dir=REAL_NORMAL_DIR,
|
||||
generator_cls=NormalCaptchaGenerator,
|
||||
config_key="normal",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user