admin管理员组

文章数量:1536088

文章目录

  • 前言
  • 一、下载数据集
  • 二、训练推理过程
  • 三、推理时间
  • 四、注意


前言

这几天看到一个实时轻量的超分辨率网络,Bicubic++: Slim, Slimmer, Slimmest - Designing an Industry-Grade Super-Resolution Network (🚀 Winner of NTIRE RTSR Challange Track 2 (x3 SR) @ CVPR 2023),于是拉下来想看一下,整体效果还是很不错,网络结构非常简单,推理速度极快并且效果好。
官方的代码使用的pytorch_lightning,封装得很深,结构不是很容易去看,另外是将训练数据全部加载到内存里面,如果是自己的数据太大容易内存不足,开始调试也会加载很久,官方使用X3 SR,当超分辨率尺度不是3训练会报错,因为整除等问题验证时会报shape不一致,官方暂时没有放出推理代码。因此我重写了一套代码解决这些问题,当然从硬盘读取训练速度会慢一些,代码仓库。

一、下载数据集

官方使用的DIV2K数据集,训练900张图,是使用3倍下采样的数据训练模型,同时指定了40张图作为验证集计算psnr,我重写的repo中,前860张作为训练集,最后40张作为验证,数据已经分好。

val_HR和val_LR分别为高低分辨率图各40张,DIV2K_train_HR 为860张高分辨率图,DIV2K_train_LR_bicubic/X3为860张3倍下采样图像。可以使用自己的数据训练,高低分辨率文件夹中的图像名称最好保持一致一一对应。

二、训练推理过程

训练:python3 train.py
推理:python3 inference.py
pytorch模型转onnx模型:python3 torch2onnx.py
onnx推理:python3 onnx_inference.py

lr图像:

bicubic插值图像:
本文bicubic++算法:

三、推理时间

单张图像推理过程不算前后处理时间,CPU:13700k GPU: 4090

pytorch GPU: 0.5ms左右
pytorch CPU: 17ms左右
onnxruntime GPU:4ms左右
onnxruntime CPU: 15ms左右
onnxruntime更慢,猜测是因为网络结构不复杂,没有需要优化的操作,当然也可能是我代码写错了。

四、注意

1、conf.yaml配置文件中patch_cropsize必须能整除超分辨率尺度参数,比如这里是3倍,patch_cropsize/3为整数。
2、degradation参数,如果blur或者img_noise为True,则会走degradation的方法,也就是从hr下采样生成lr而不是从文件夹读lr图像,这个过程比较慢,一张图大概一秒,如果在训练中开启了,训练速度会变慢,建议可以直接使用该方法将lr图像生成到本地再训练。
如果对您有帮助,不妨star一下!!!

本文标签: 实时分辨率网络轻量超bicubic