admin管理员组

文章数量:1582013

在做的实验基础代码是用的 Pytorch-Lightning 中的训练器 Trainer 进行训练

  1. 首先需要保存的训练后的模型参数,保存 checkpoint 断点
checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=args.ckpt_dir + "/" + args.model_type,
        filename=model_savename + "---{epoch}---" + dt_string +'-'+str(args.use_img)+str(args.use_att)+str(args.use_date)+str(args.use_trends)+'RNN3_5',#str(note)
        monitor="val_mae",
        mode="min",#这里实验效果是越小越好,所以是“min”
        save_top_k=5,#1
    )
    
print(checkpoint_callback.best_model_path)#打印出效果最好的模型参数存储的路径

这里保存了效果前五的模型,这里的实验效果是越小越好,并打印出效果最好的模型参数存储路径

  1. 在训练器 Trainer 里加载之前保存的最佳模型
 trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=testloader,ckpt_path='自己替换成最佳模型参数所存在的路径.ckpt')

主要是 trainer.fit() 函数里,ckpt_path 参数所提供的效果,输入 ckpt 文件路径(从这里文件恢复训练)

参考博客:https://blog.csdn/qq_27135095/article/details/122635743?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167583461916800180668936%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167583461916800180668936&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_ecpm_v1~rank_v31_ecpm-1-122635743-null-null.142%5Ev73%5Econtrol,201%5Ev4%5Eadd_ask,239%5Ev1%5Einsert_chatgpt&utm_term=pl%20trainer%20%E6%98%AF%E5%A6%82%E4%BD%95%E8%AE%AD%E7%BB%83%E7%9A%84&spm=1018.2226.3001.4187

本文标签: 深度参数ckptTrainerPytorch