Contents
- RESA: Recurrent Feature-Shift Aggregator for Lane Detection
- Advantages
- Upsampling
- Datasets
- Contributions
- Related work
- Traditional methods
- Deep learning methods
- Use of spatial information
- Method
- Framework design
- RESA
- Advantages
- Bilateral up-sampling decoder
- Coarse grained branch
- Fine detailed branch
- Experiment
- Code reproduction
- Environment setup
- Datasets
- TuSimple
- Testing
- configs/tusimple.py
- datasets/base_dataset.py
- datasets/registry.py
- datasets/tusimple.py
- utils (appears to serve configs)
- utils/registry.py
- models/decoder.py
- models/registry.py
- models/resa.py
- models/resnet.py
- runner/evaluator/tusimple/tusimple.py
- runner/logger.py
- runner/net_utils.py
- runner/optimizer.py
- runner/recorder.py
- runner/registry.py
- runner/resa_trainer.py (loss computation)
- runner/scheduler.py (learning-rate decay policy, LambdaLR)
- runner/runner.py
- main.py
- Reproducing the code on Colab
RESA: Recurrent Feature-Shift Aggregator for Lane Detection
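- Current algorithms typically use a pixel-wise segmentation formulation: lane detection is treated as a segmentation problem in which every pixel gets a binary label indicating whether it belongs to a lane line.
- These works use an encoder-decoder framework.
- First, a CNN encoder extracts high-level semantic features into a feature map; then an up-sampling decoder restores the feature map to the original size and finally performs pixel-wise prediction.
- Lanes are long and thin, so lane pixels are heavily outnumbered by background pixels. This makes subtle lane features hard to extract, and the network may ignore the shape prior and the correlation between lanes, which degrades detection performance.
- Lanes may be occluded by crowded vehicles, in which case the lane position can only be inferred from common sense.
- SCNN proposes spatial convolution to pass information between rows and columns. However, this message passing is time-consuming, so inference is slow. Because information is passed sequentially between adjacent rows or columns, many iterations are needed and information may be lost over long distances.
- This paper proposes the REcurrent Feature-Shift Aggregator (RESA), which aggregates information within the feature map and passes spatial information more directly and efficiently.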
Advantages:
- RESA passes information in a parallel way, which greatly reduces the time cost.
- Information is passed with different strides in RESA, so different sliced feature maps can be gathered without information loss during propagation.
- RESA is simple and flexible to incorporate into other networks.
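Upsampling
- A bilateral up-sampling decoder is proposed: one branch captures coarse-grained features and the other captures fine-detailed features.
- The coarse-grained branch applies bilinear up-sampling directly and produces a blurry image; the fine-detailed branch up-samples with transposed convolution and uses two non-bottleneck blocks to recover the lost detail.
- Combining the two branches, the decoder can carefully restore the low-resolution feature map to pixel-wise predictions.
Datasets
CULane and TuSimple
Contributions
- RESA, a structure for aggregating spatial information
- The bilateral up-sampling decoder
- Good results on both datasets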
Related work
Traditional methods
- Sun, Tsai, and Chan (2006) try to detect lanes in the HSI color representation.
- Yu and Jain (1997) extract lane boundaries via the Hough Transform.
Deep learning methods
- Huval et al. (2015) are the first to apply deep learning to lane detection, using a CNN.
- Neven et al. (2018) propose to cast the lane detection problem as an instance segmentation problem.
- Philion (2019) integrates the lane decoding step into the network and draws lanes iteratively without a recurrent neural network.
- Self-attention distillation (SAD) is proposed to allow a model to learn from itself and gains substantial improvement without any additional supervision or labels (Hou et al. 2019).
Use of spatial information
- ION (Bell et al. 2016) explores the use of spatial Recurrent Neural Networks (RNNs).
- Liang et al. (2016) construct Graph LSTM to provide an information propagation route for semantic object parsing.
- SCNN (Pan et al. 2018) proposes to generalize traditional layer-by-layer convolutions to slice-by-slice convolutions within feature maps.
Method
- The model has three parts: encoder, aggregator, and decoder.
- Commonly used backbones such as ResNet (He et al. 2016) and VGG (Simonyan and Zisserman 2015) serve as the encoder to extract preliminary features from the raw image.
- The RESA module aggregates lane information.
- A bilateral up-sampling decoder produces the prediction.
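Framework design
The network structure is as follows:
- Encoder: a common backbone such as VGG or ResNet extracts preliminary features; the resulting feature map is 1/8 of the input size.
- RESA module: aggregates spatial information. Each iteration passes information in four directions, and the module iterates K times in total so that every position can receive information from the whole feature map.
- Bilateral up-sampling decoder: each block up-samples by a factor of 2, restoring the 1/8-size feature map to the original size.
- After up-sampling by the decoder, the output feature map is used to predict the existence and the probability distribution of each lane.
- For existence prediction, a fully-connected layer follows and performs binary (0-1) classification.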
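The claim that K iterations let every position receive information from the whole feature map can be checked with a small sketch. It assumes the stride schedule 2^k (k = 0..K-1) with K = log2(L) for a feature map of L rows, considers a single direction only, and tracks which rows have contributed to each row rather than real features. The function name `rows_reached` is hypothetical.

```python
def rows_reached(num_rows, num_iters):
    """Return, for each row, the set of source rows whose information has reached it.

    Assumption: in iteration k every row i receives from row (i + 2**k) % num_rows,
    i.e. a cyclic shift with stride 2**k applied to all rows in parallel.
    """
    reached = [{i} for i in range(num_rows)]
    for k in range(num_iters):
        stride = 2 ** k
        # All rows are updated at once from the previous iteration's state.
        reached = [reached[i] | reached[(i + stride) % num_rows]
                   for i in range(num_rows)]
    return reached


coverage = rows_reached(num_rows=8, num_iters=3)
print(all(len(sources) == 8 for sources in coverage))  # True
```

With 8 rows, three parallel shifts are enough, whereas passing information only between adjacent rows would need seven sequential steps.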
RESA
Advantages
- Computationally efficient.
- Feature information is gathered effectively.
- Easy to plug into other networks.
Bilateral up-sampling decoder
- Most decoders use bilinear up-sampling for pixel-wise prediction, which easily yields a coarse result and may lose detail.
- Some methods (Romera et al. 2017) stack convolution and deconvolution operations to obtain refined up-sampling results.
- Motivated by the above, the paper combines the advantages of both and proposes the bilateral up-sampling decoder.
Coarse grained branch
- 1 × 1 convolution -> BN -> bilinear up-sampling -> ReLU
Fine detailed branch
- transposed convolution with stride 2 -> ReLU -> non-bottleneck block -> non-bottleneck block (each block consists of four 3 × 1 and 1 × 3 convolutions with BN and ReLU)
Experiment
We use SGD (Bottou 2010) with momentum 0.9 and weight decay 1e-4 as the optimizer. The learning rate is set to 1.6e-2 for CULane and 2.5e-2 for TuSimple. We use a warm-up strategy (Dollár, Girshick, and Noordhuis 2017) in the first 500 batches and then apply a polynomial learning-rate decay policy (Mishra and Sarawadekar 2019) with the power set to 0.9.
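The schedule described above can be sketched as a single function. This is a minimal sketch, not the repository's scheduler: the linear shape of the warm-up and the argument names are assumptions, and `total_steps` is whatever number of batches the run is configured for.

```python
def learning_rate(step, base_lr, total_steps, warmup_steps=500, power=0.9):
    """Learning rate at a given 0-based step.

    Assumptions: linear warm-up up to base_lr over the first warmup_steps,
    then polynomial decay base_lr * (1 - step / total_steps) ** power.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr * (1.0 - step / total_steps) ** power


# TuSimple base learning rate from the paper, with a hypothetical run length.
for step in (0, 499, 40000, 80000):
    print(step, learning_rate(step, base_lr=2.5e-2, total_steps=80000))
```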
- Loss functions: segmentation BCE loss and existence classification CE loss.
- Considering the imbalanced labels between background and lane markings, the segmentation loss of the background is multiplied by 0.4.
- The batch size is set to 8 for CULane and 16 for TuSimple.
- The total number of training epochs is set to 50 for TuSimple and 12 for CULane.
- All models are trained with 4 Nvidia 1080Ti GPUs. All experiments are implemented with PyTorch.
- In our experiments, we use ResNet (He et al. 2016) and VGG (Simonyan and Zisserman 2014) as backbones. In ResNet, we add an extra 1 × 1 convolution to reduce the output channels to 128. The modification of VGG is the same as in SCNN (Pan et al. 2018).
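To illustrate how down-weighting the background by 0.4 enters the segmentation loss, here is a pure-Python sketch of a class-weighted cross-entropy over a handful of pixels. It is not the repository's loss code; the function name is hypothetical, and normalizing by the sum of the weights is an assumption.

```python
import math


def weighted_cross_entropy(logits, labels, num_classes, bg_weight=0.4):
    """Weighted cross-entropy over a flat list of pixels.

    logits: one list of num_classes raw scores per pixel.
    labels: one class index per pixel, 0 meaning background.
    Assumption: the weighted sum is normalized by the sum of the weights.
    """
    weights = [bg_weight] + [1.0] * (num_classes - 1)
    total, norm = 0.0, 0.0
    for scores, label in zip(logits, labels):
        # Numerically stable log-sum-exp for the softmax denominator.
        top = max(scores)
        log_sum = top + math.log(sum(math.exp(s - top) for s in scores))
        total += weights[label] * (log_sum - scores[label])
        norm += weights[label]
    return total / norm


# One misclassified background pixel and one uncertain lane pixel.
print(weighted_cross_entropy([[0.0, 2.0], [0.0, 0.0]], [0, 1], num_classes=2))
```

Because the background pixel carries weight 0.4, its error contributes less to the average than the same error on a lane pixel would.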
Code reproduction
code: https://github.com/ZJULearning/resa/tree/fe4e7314ebfb2de17f5b75539cb2ed3c28bf6f96
Environment setup
conda create -n resa python=3.8 -y
conda activate resa
pip install torch torchvision
pip install -r requirement.txt
Datasets
TuSimple
- structure
$TUSIMPLEROOT/clips # data folders
$TUSIMPLEROOT/label_data_xxxx.json # label json file x3
$TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
$TUSIMPLEROOT/test_label.json # test label json file
- TuSimple ships no segmentation labels; generate them from the json files
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this will generate seg_label directory
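To make the json-to-label step less opaque, the sketch below reads one line of a TuSimple label file. In that format each line is a JSON object with `lanes` (x coordinates per lane, -2 where the lane is absent), `h_samples` (the shared y coordinates) and `raw_file`. This is not the repository's `generate_seg_tusimple.py`; the function name and the sample path are hypothetical.

```python
import json


def lanes_to_points(label_line):
    """Convert one line of a TuSimple label file into per-lane (x, y) point lists."""
    record = json.loads(label_line)
    points = []
    for xs in record["lanes"]:
        # x == -2 means the lane has no point on that sampled row.
        lane = [(x, y) for x, y in zip(xs, record["h_samples"]) if x >= 0]
        if lane:
            points.append(lane)
    return record["raw_file"], points


sample = json.dumps({
    "lanes": [[-2, 632, 625], [-2, -2, -2]],
    "h_samples": [240, 250, 260],
    "raw_file": "clips/example/20.jpg",  # hypothetical path
})
print(lanes_to_points(sample))
# ('clips/example/20.jpg', [[(632, 250), (625, 260)]])
```

A segmentation label can then be drawn by connecting each lane's points with a thick polyline whose pixel value is the lane index.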
Testing
Trained models provided by the authors
(TuSimple: GoogleDrive/BaiduDrive(code:s5ii), CULane: GoogleDrive/BaiduDrive(code:rlwj))
Without visualization
python main.py configs/tusimple.py --validate --load_from ./tusimple_resnet34.pth --gpus 0
With visualization
python main.py configs/tusimple.py --validate --load_from ./tusimple_resnet34.pth --gpus 0 --view
Training
python main.py configs/tusimple.py --gpus 0 --work_dirs /content/drive/MyDrive/chxsave
configs/tusimple.py
net = dict(
type='RESANet',
)
backbone = dict(
type='ResNetWrapper',
resnet='resnet34',
pretrained=True,
replace_stride_with_dilation=[False, True, True], # use dilated convolution instead of stride in layer3 and layer4
out_conv=True,
fea_stride=8,
)
resa = dict(
type='RESA',
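alpha=2.0, # scale applied to each shifted, convolved feature before it is added back
iter=5, # number of shift iterations per direction
input_channel=128,
conv_stride=9, # length of the 1-D convolution kernel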
)
decoder = 'BUSD'
trainer = dict(
type='RESA'
)
evaluator = dict(
type='Tusimple',
thresh = 0.60
)
optimizer = dict(
type='sgd',
lr=0.020,
weight_decay=1e-4,
momentum=0.9
)
total_iter = 80000
import math
scheduler = dict(
type = 'LambdaLR',
lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
)
bg_weight = 0.4
img_norm = dict(
mean=[103.939, 116.779, 123.68],
std=[1., 1., 1.]
)
img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"
dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'
dataset = dict(
train=dict(
type='TuSimple',
img_path=dataset_path,
data_list='train_val_gt.txt',
),
val=dict(
type='TuSimple',
img_path=dataset_path,
data_list='test_gt.txt'
),
test=dict(
type='TuSimple',
img_path=dataset_path,
data_list='test_gt.txt'
)
)
loss_type = 'cross_entropy'
seg_loss_weight = 1.0
batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''
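Several sizes used later in the model follow directly from this config. The arithmetic below is a worked example under the values shown above (368 × 640 input, `replace_stride_with_dilation=[False, True, True]`, 7 classes); the variable names are only for illustration.

```python
replace_stride_with_dilation = [False, True, True]
img_height, img_width, num_classes = 368, 640, 6 + 1

# conv1 and max-pool each halve the resolution; layer1 keeps it.
stride = 2 * 2
# layer2..layer4 each halve it again unless the stride is replaced by dilation.
for dilate in replace_stride_with_dilation:
    if not dilate:
        stride *= 2

feat_h, feat_w = img_height // stride, img_width // stride
# ExistHead average-pools by 2 before flattening into its first linear layer.
fc_in = num_classes * (feat_h // 2) * (feat_w // 2)

print(stride, (feat_h, feat_w), fc_in)  # 8 (46, 80) 6440
```

The resulting stride matches `fea_stride=8`, and 6440 is the input width of `fc9` in `ExistHead`.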
datasets/base_dataset.py
img_path -> list_path (the list directory is located under img_path)
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)
        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []
        self.transform = self.transform_train() if self.is_training else self.transform_val()
        self.init()

    def transform_train(self):
        raise NotImplementedError()

    # Resize, then normalize the image with the configured mean and std
    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    # Visualization; file_path is where the rendered image is saved
    def view(self, img, coords, file_path=None):
        for coord in coords:
            # Take the x, y coordinates of each lane point
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                # Draw each lane point as a circle
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)

    def init(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        # Read the image
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        # Crop away the top cut_height rows
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            # Training: also load the segmentation label and the existence vector
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                # Image preprocessing
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        # Reorder the tensor dimensions from HWC to CHW
        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
datasets/registry.py
from utils import Registry, build_from_cfg
import torch
import torch.nn as nn  # needed by nn.Sequential in build()

DATASETS = Registry('datasets')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)


def build_dataloader(split_cfg, cfg, is_train=True):
    if is_train:
        shuffle = True  # shuffle the samples during training
    else:
        shuffle = False

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size = cfg.batch_size, shuffle = shuffle,
        num_workers = cfg.workers, pin_memory = False, drop_last = False)

    return data_loader
datasets/tusimple.py
- init
- transform_train
- fix_gap
- is_short
- get_lane
- probmap2lane
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    # Preprocessing for training images
    def transform_train(self):
        input_mean = self.cfg.img_norm['mean']
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5]),
                              int(line_split[6]), int(line_split[7])
                              ]))

    # Fill missing points inside a lane by linear interpolation
    def fix_gap(self, coordinate):
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate

    # Returns 1 when the lane has no valid point at all
    def is_short(self, lane):
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map: prob map for single lane, np array size (h, w)
        resize_shape: reshape size target, (H, W)
        Return:
        ----------
        coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height

        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)
        return coords

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred: np.array size (num_classes, h, w), channel 0 is background
        resize_shape: reshape size target, (H, W)
        exist: list of existence, e.g. [0, 1, 1, 0]
        smooth: whether to smooth the probability or not
        y_px_gap: y pixel gap for sampling
        pts: how many points for one lane
        thresh: probability threshold
        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (num_classes, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []

        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])
        return coordinates
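The row arithmetic in `get_lane` and `probmap2lane` is easier to follow with concrete numbers. The worked example below plugs in the defaults shown above (`resize_shape=(720, 1280)`, `y_px_gap=10`, `pts=56`) together with `cut_height=160` and a 368-row probability map from the config; the variable names are illustrative only.

```python
resize_h, cut_height = 720, 160   # original image height, rows cropped from the top
prob_h = 368                      # height of the predicted probability map
y_px_gap, pts = 10, 56

# Rows in the original image where lane points are reported, bottom-up.
ys = [resize_h - 10 - j * y_px_gap for j in range(pts)]

# Matching rows in the probability map, which covers only the cropped region.
cropped_h = resize_h - cut_height
rows = [int((cropped_h - 10 - j * y_px_gap) * prob_h / cropped_h) for j in range(pts)]

print(ys[0], ys[-1])      # 710 160
print(rows[0], rows[-1])  # 361 0
```

The 56 sampled rows therefore run from y = 710 up to y = 160, which is exactly where the crop begins.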
utils (appears to serve configs)
utils/registry.py
import inspect
import six


# borrowed from mmdetection
def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


class Registry(object):
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.
        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    # Constructor arguments come only from default_args, not from cfg itself
    args = {}
    obj_type = cfg.type
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
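How the registry ties a config entry to a class can be seen in a stripped-down, self-contained sketch. It simplifies the code above: the config is a plain dict indexed with `cfg["type"]` rather than an attribute-style config object, and `DEMO` and `ToyDataset` are hypothetical names that do not exist in the repository.

```python
class Registry:
    """Minimal name -> class registry (simplified from the code above)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg(cfg, registry, default_args=None):
    obj_cls = registry.get(cfg["type"])
    if obj_cls is None:
        raise KeyError("{} is not in the {} registry".format(cfg["type"], registry.name))
    return obj_cls(**(default_args or {}))


DEMO = Registry("demo")  # hypothetical registry


@DEMO.register_module
class ToyDataset:  # hypothetical class, not part of the repository
    def __init__(self, img_path, data_list):
        self.img_path = img_path
        self.data_list = data_list


split_cfg = {"type": "ToyDataset", "img_path": "./data", "data_list": "train.txt"}
# Mirrors build_dataset: every key except "type" becomes a constructor argument.
args = {k: v for k, v in split_cfg.items() if k != "type"}
dataset = build_from_cfg(split_cfg, DEMO, default_args=args)
print(type(dataset).__name__, dataset.data_list)  # ToyDataset train.txt
```

The decorator registers the class under its own name, so switching datasets or networks only requires changing the `type` string in the config.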
models/decoder.py
from torch import nn
import torch.nn.functional as F


class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        # Bilinear interpolation back to the input resolution
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)
        return x


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True,
                                   dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True,
                                   dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # +input = identity (residual connection)
        return F.relu(output + input)


# Fine detailed branch + Coarse grained branch
class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        # Fine detailed branch: transposed convolution + two non-bottleneck blocks
        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # Coarse grained branch: 1x1 convolution + bilinear interpolation
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate


# Bilateral Up-sampling Decoder
class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                                          up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                                          up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                                          up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = self.output_conv(output)

        return output
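The two branches of each `UpsamplerBlock` can only be added if they produce the same spatial size. The sketch below checks this with the standard output-size formula for a transposed convolution (dilation 1), using the kernel, stride and padding values from the code above and the 46 × 80 feature map implied by the 368 × 640 config; the helper name is hypothetical.

```python
def transposed_conv_size(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


height, width, channels = 46, 80, 128   # 1/8 of 368 x 640, as fed to BUSD
shapes = []
for out_channels in (64, 32, 16):
    height, width = transposed_conv_size(height), transposed_conv_size(width)
    channels = out_channels
    shapes.append((channels, height, width))

print(shapes)  # [(64, 92, 160), (32, 184, 320), (16, 368, 640)]
```

These sizes equal `img_height // 4`, `// 2` and `// 1` (and likewise for the width), which are the targets passed to the bilinear branch.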
models/registry.py
from utils import Registry, build_from_cfg
from torch import nn  # needed by nn.Sequential in build()

NET = Registry('net')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))
models/resa.py
import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
from .resnet import ResNetWrapper
from .decoder import BUSD, PlainDecoder


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
            # nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
            # Vertical passes (down / up) use a 1 x conv_stride kernel along the width
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            # Horizontal passes (right / left) use a conv_stride x 1 kernel along the height
            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            # Cyclic shift indices; // is floor division, ** is exponentiation
            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()

        # Shift along the height dimension, then add the convolved result back
        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        # Shift along the width dimension
        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x


# Lane existence head
class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x


@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        fea = self.backbone(batch)
        fea = self.resa(fea)
        seg = self.decoder(fea)
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output
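The shift indices built in `RESA.__init__` can be inspected without PyTorch. This sketch reproduces the same arithmetic with plain lists for the config values used here (`iter=5`, feature map 46 × 80); the two helper names are hypothetical.

```python
def shift_strides(length, num_iters):
    """Shift used in each iteration: length // 2 ** (num_iters - i)."""
    return [length // 2 ** (num_iters - i) for i in range(num_iters)]


def shifted_index(length, shift):
    """Cyclic gather index, as in (torch.arange(length) + shift) % length."""
    return [(j + shift) % length for j in range(length)]


print(shift_strides(46, 5))   # [1, 2, 5, 11, 23]
print(shift_strides(80, 5))   # [2, 5, 10, 20, 40]
print(shifted_index(6, 2))    # [2, 3, 4, 5, 0, 1]
```

The strides roughly double from one iteration to the next, so the last iteration gathers from half a feature map away, and the modulo makes every shift wrap around the border.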
models/resnet.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
# This code is borrow from torchvision.
model_urls = {
'resnet18': 'https://download.pytorch/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
# if dilation > 1:
# raise NotImplementedError(
# "Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, dilation=dilation)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNetWrapper(nn.Module):
def __init__(self, cfg):
super(ResNetWrapper, self).__init__()
self.cfg = cfg
self.in_channels = [64, 128, 256, 512]
if 'in_channels' in cfg.backbone:
self.in_channels = cfg.backbone.in_channels
self.model = eval(cfg.backbone.resnet)(
pretrained=cfg.backbone.pretrained,
replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels)
self.out = None
if cfg.backbone.out_conv:
out_channel = 512
for chan in reversed(self.in_channels):
if chan < 0: continue
out_channel = chan
break
self.out = conv1x1(
out_channel * self.model.expansion, 128)
def forward(self, x):
x = self.model(x)
if self.out:
x = self.out(x)
return x
class ResNet(nn.Module):
def __init__(self, block, layers, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, in_channels=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.in_channels = in_channels
self.layer1 = self._make_layer(block, in_channels[0], layers[0])
self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
if in_channels[3] > 0:
self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.expansion = block.expansion
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if self.in_channels[3] > 0:
x = self.layer4(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict, strict=False)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)

def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)

def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)

def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)

def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)

def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)

def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
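The block lists passed to `_resnet` explain the number in each model name. Below is a minimal sketch of that arithmetic, assuming the standard ResNet layout: one stem convolution, two convolutions per `BasicBlock`, three per `Bottleneck`, and one final fully connected layer (which this backbone comments out in `forward`, so the count is only the nominal one). The helpers `nominal_depth` and `bottleneck_width` are hypothetical names for illustration; the width formula is the one torchvision's `Bottleneck` uses, which this file appears to follow.

```python
# Hypothetical helpers for illustration; they are not part of models/resnet.py.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def nominal_depth(block, layers):
    """Weighted layers counted in the model name: stem conv + block convs + fc."""
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


def bottleneck_width(planes, groups=1, width_per_group=64):
    """Channels of the 3x3 conv inside a Bottleneck (torchvision-style formula)."""
    return int(planes * (width_per_group / 64.0)) * groups


depths = {
    "resnet18": nominal_depth("BasicBlock", [2, 2, 2, 2]),
    "resnet34": nominal_depth("BasicBlock", [3, 4, 6, 3]),
    "resnet50": nominal_depth("Bottleneck", [3, 4, 6, 3]),
    "resnet101": nominal_depth("Bottleneck", [3, 4, 23, 3]),
    "resnet152": nominal_depth("Bottleneck", [3, 8, 36, 3]),
}
print(depths)

# Last stage (planes=512): plain ResNet-50 vs. wide_resnet50_2 (width_per_group=128).
print(bottleneck_width(512), bottleneck_width(512, width_per_group=64 * 2))
# First stage (planes=64) of resnext50_32x4d: 32 groups of width 4.
print(bottleneck_width(64, groups=32, width_per_group=4))
```

The doubled inner width of the wide variant (512 to 1024 in the last stage) matches the "2048-512-2048" versus "2048-1024-2048" remark in the docstring above.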
runner/evaluator/tusimple/tusimple.py
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger
from runner.registry import EVALUATOR
import json
import os
import cv2
from .lane import LaneEval

def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders

@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        if not os.path.exists(exp_dir):
            os.mkdir(exp_dir)
        self.out_path = os.path.join(exp_dir, "coord_output")
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)  # directory for the predicted coordinates
        self.dump_to_json = []
        self.thresh = cfg.evaluator.thresh
        self.logger = get_logger('resa')  # logger
        if cfg.view:
            self.view_dir = os.path.join(self.cfg.work_dir, 'vis')  # visualization output

    def evaluate_pred(self, dataset, seg_pred, exist_pred, batch):
        img_name = batch['meta']['img_name']
        img_path = batch['meta']['full_img_path']
        for b in range(len(seg_pred)):  # per image: decide which lanes exist
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] >
                     0.5 else 0 for i in range(self.cfg.num_classes-1)]
            lane_coords = dataset.probmap2lane(seg, exist, thresh = self.thresh)
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(
                    lane_coords[i], key=lambda pair: pair[1])
            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)
            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            for (x, y) in lane_coords[0]:
                json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))
            if self.cfg.view:
                img = cv2.imread(img_path[b])
                new_img_name = img_name[b].replace('/', '_')
                save_dir = os.path.join(self.view_dir, new_img_name)
                dataset.view(img, lane_coords, save_dir)

    def evaluate(self, dataset, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()
        exist_pred = exist_pred.detach().cpu().numpy()
        self.evaluate_pred(dataset, seg_pred, exist_pred, batch)

    def summarize(self):
        best_acc = 0
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)
        eval_result, acc = LaneEval.bench_one_submit(output_file,
                                                     self.cfg.test_json_file)
        self.logger.info(eval_result)
        self.dump_to_json = []
        best_acc = max(acc, best_acc)
        return best_acc
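To make the output of `evaluate_pred` concrete, here is a minimal standalone sketch of how one image's lane coordinates become one line of `predict_test.json`. The lane coordinates and the image path are made-up sample values, and `split_path` is copied from the listing above; nothing here calls the real dataset or `LaneEval`.

```python
import json
import os


def split_path(path):
    """Split a path into its components (same logic as the listing above)."""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders


# Made-up prediction: two lanes given as (x, y) points, plus one empty lane.
lane_coords = [
    [(640, 300), (600, 280), (620, 290)],
    [(700, 300), (690, 290), (680, 280)],
    [],
]
# Same ordering step as evaluate_pred: sort every lane by its y coordinate.
lane_coords = [sorted(lane, key=lambda pair: pair[1]) for lane in lane_coords]

img_name = "clips/0530/1492626047222176976_0/20.jpg"  # hypothetical sample path
path_tree = split_path(img_name)

json_dict = {
    # Only x values are stored per lane; empty lanes are skipped.
    "lanes": [[int(x) for (x, y) in lane] for lane in lane_coords if lane],
    # The y values are taken from the first lane only, as in the original code.
    "h_sample": [y for (x, y) in lane_coords[0]],
    "raw_file": os.path.join(*path_tree[-4:]),
    "run_time": 0,
}
line = json.dumps(json_dict)
print(line)
```

Because `h_sample` is read from `lane_coords[0]`, the format relies on every lane being sampled at the same set of y values.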
runner/logger.py
- Logging
import logging

logger_initialized = {}

def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger
    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]
    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    logger.setLevel(log_level)
    logger_initialized[name] = True
    return logger
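The hierarchical-name shortcut in `get_logger` works because Python's `logging` module propagates records from a child logger to the handlers of its parent. A small standalone sketch using only the standard library (the logger names `demo` and `demo.eval` are made up for the example):

```python
import io
import logging

# Capture log output in memory so the result can be inspected.
stream = io.StringIO()

parent = logging.getLogger("demo")
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
parent.addHandler(handler)
parent.setLevel(logging.INFO)

# The child has no handler of its own; its records propagate up to "demo".
child = logging.getLogger("demo.eval")
child.info("evaluation started")

output = stream.getvalue()
print(output, end="")
```

Note that the check in `get_logger` is a plain `startswith` prefix test, so a name such as `resa2` would also be treated as already initialized once `resa` exists, although `logging` itself does not consider it a child of `resa`.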
runner/net_utils.py
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger

# Save a checkpoint; the best model so far is saved as 'best.pth'
def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.system('mkdir -p {}'.format(model_dir))  # -p creates nested directories
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))

def load_network_specified(net, model_dir, logger=None):
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        # skip weights that are missing from the network or have a different shape
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)

def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)
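The filtering rule in `load_network_specified` can be illustrated without PyTorch. In this minimal sketch, shape tuples stand in for tensors, and the function name `filter_state` and all key names are hypothetical:

```python
def filter_state(pretrained, current):
    """Keep entries whose key exists in `current` with an identical shape."""
    kept, skipped = {}, []
    for key, shape in pretrained.items():
        if key not in current or shape != current[key]:
            skipped.append(key)  # the real code logs 'skip weights: <key>'
            continue
        kept[key] = shape
    return kept, skipped


# Shapes stand in for tensors; the key names are made up.
pretrained = {
    "conv1.weight": (64, 3, 7, 7),
    "decoder.weight": (7, 128, 1, 1),  # e.g. a head trained with 7 classes
    "fc.weight": (1000, 512),          # not present in the target network
}
current = {
    "conv1.weight": (64, 3, 7, 7),
    "decoder.weight": (5, 128, 1, 1),  # the target network has 5 classes
}

kept, skipped = filter_state(pretrained, current)
print("kept:", sorted(kept))
print("skipped:", skipped)
```

This is why finetuning uses `strict=False`: only the matching subset is loaded, whereas resuming with `load_from` requires every key to match (`strict=True`).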
runner/optimizer.py
import torch

_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}

def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay
    for key, value in net.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
            params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)
    return optimizer
runner/recorder.py
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
    # Stores a series of values and exposes their median and average
    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()  # folder named by timestamp and hyperparameters
        cfg.work_dir = self.work_dir  # store it back into cfg
        self.log_path = os.path.join(self.work_dir, 'log.txt')  # log file
        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)  # write the config to the log
        # scalars
        self.epoch = 0
        self.step = 0
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter
        self.lr = 0.

    def get_work_dir(self):  # create the work directory
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # timestamp part of the folder name
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)  # hyperparameter part
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = ' '.join(loss_state)
        recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string)

def build_recorder(cfg):
    return Recorder(cfg)
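The difference between the windowed average and the global average, and the way the ETA is derived from them, can be sketched in pure Python. `WindowedAverage` is a hypothetical stand-in for `SmoothedValue` (it uses `sum` instead of `torch`), and the timings and step counts are made-up sample values:

```python
import datetime
from collections import deque


class WindowedAverage:
    """Pure-Python stand-in for SmoothedValue (hypothetical name)."""

    def __init__(self, window_size=20):
        self.window = deque(maxlen=window_size)  # keeps only the latest values
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.window.append(value)
        self.count += 1
        self.total += value

    @property
    def avg(self):
        return sum(self.window) / len(self.window)

    @property
    def global_avg(self):
        return self.total / self.count


batch_time = WindowedAverage(window_size=2)
for seconds in (3.0, 1.0, 2.0):
    batch_time.update(seconds)

window_avg = batch_time.avg          # last two values only: (1.0 + 2.0) / 2
overall_avg = batch_time.global_avg  # all three values: 6.0 / 3

# ETA as in Recorder.__str__: global average batch time x remaining iterations.
max_iter, step = 80000, 79000
eta = str(datetime.timedelta(seconds=int(overall_avg * (max_iter - step))))
print(window_avg, overall_avg, eta)
```

The log line therefore reports recent speed (`batch`, `data`) from the window, while the ETA is based on the average over the whole run.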
runner/registry.py
import torch.nn as nn  # needed by build() for nn.Sequential
from utils import Registry, build_from_cfg

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
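The registry pattern used here maps the `type` string in a config entry to a registered class. The sketch below is a simplified, hypothetical stand-in for `utils.Registry` and `build_from_cfg`, not the repository's implementation. It assumes that only `default_args` are passed to the constructor, which is consistent with `Tusimple.__init__(self, cfg)` taking a single `cfg` argument; the `thresh` value is made up.

```python
class Registry:
    """Minimal name -> class registry (simplified stand-in)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, cls):
        # Used as a decorator: store the class under its own name.
        self._modules[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._modules[key]


def build_from_cfg(cfg, registry, default_args=None):
    """Look up cfg['type'] in the registry and instantiate it (assumed behavior)."""
    cls = registry.get(cfg["type"])
    return cls(**(default_args or {}))


EVALUATOR = Registry("evaluator")


@EVALUATOR.register_module
class Tusimple:
    def __init__(self, cfg):
        # The module reads its own settings from the full config.
        self.thresh = cfg["evaluator"]["thresh"]


cfg = {"evaluator": {"type": "Tusimple", "thresh": 0.6}}  # made-up sample config
evaluator = build_from_cfg(cfg["evaluator"], EVALUATOR, default_args={"cfg": cfg})
print(type(evaluator).__name__, evaluator.thresh)
```

With this design, switching the trainer or evaluator only requires changing the `type` string in the config file.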
runner/resa_trainer.py ---- loss computation
- X: ground-truth (GT) segmentation map; Y: predicted segmentation map.
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.registry import TRAINER

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()
    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1-d).mean()

@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()
        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])
        loss_stats = {}
        loss = 0.
        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())
        loss += seg_loss * self.cfg.seg_loss_weight
        loss_stats.update({'seg_loss': seg_loss})
        if 'exist' in output:
            exist_loss = 0.1 * \
                self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})
        ret = {'loss': loss, 'loss_stats': loss_stats}
        return ret
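The `dice_loss` above computes, per sample, 1 - 2·Σ(xy) / (Σx² + Σy² + 0.002) and then averages over the batch. A pure-Python sketch of the same arithmetic on tiny made-up inputs shows how it behaves; flattened lists stand in for tensors, and `eps` is a name introduced here for the 0.001 constant:

```python
def dice_loss(inputs, targets, eps=0.001):
    """Same arithmetic as the PyTorch version, on lists of flattened samples."""
    losses = []
    for x, y in zip(inputs, targets):
        a = sum(xi * yi for xi, yi in zip(x, y))  # overlap
        b = sum(xi * xi for xi in x) + eps        # prediction energy
        c = sum(yi * yi for yi in y) + eps        # target energy
        losses.append(1 - (2 * a) / (b + c))
    return sum(losses) / len(losses)


target = [[1.0, 1.0, 0.0, 0.0]]  # made-up 4-pixel lane mask

perfect = dice_loss([[1.0, 1.0, 0.0, 0.0]], target)  # prediction equals the mask
wrong = dice_loss([[0.0, 0.0, 1.0, 1.0]], target)    # no overlap at all
half = dice_loss([[0.5, 0.5, 0.5, 0.5]], target)     # uncertain everywhere

print(perfect, half, wrong)
```

Because the loss depends on overlap rather than on per-pixel counts, thin lanes are not swamped by the large number of background pixels, which is the usual motivation for Dice-style losses in imbalanced segmentation.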
runner/scheduler.py: learning-rate decay strategy (LambdaLR)
import torch
import math

_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}

def build_scheduler(cfg, optimizer):
    assert cfg.scheduler.type in _scheduler_factory
    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')
    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)
    return scheduler
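`LambdaLR` multiplies the optimizer's initial learning rate by the value of a user-supplied `lr_lambda` function, which is whatever remains in `cfg.scheduler` after `type` is removed. The sketch below assumes a polynomial decay of the form (1 - iter/total_iter)^0.9, a common choice for segmentation training; this section does not show the actual lambda, so check `lr_lambda` in `configs/tusimple.py`. The `base_lr` value is made up, and `total_iter = 80000` follows the comment in `runner.py`.

```python
import math

total_iter = 80000  # from the '#80000' comment in runner.py
base_lr = 0.02      # made-up sample value


def lr_lambda(step):
    # Assumed polynomial decay; not confirmed by this section.
    return math.pow(1 - step / total_iter, 0.9)


def lr_at(step):
    # LambdaLR sets lr = initial_lr * lr_lambda(step).
    return base_lr * lr_lambda(step)


for step in (0, 20000, 40000, 79999):
    print(step, round(lr_at(step), 6))
```

Since the runner calls `scheduler.step()` once per training iteration rather than once per epoch, the argument received by `lr_lambda` counts iterations.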
runner/runner.py
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup
from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network

class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)  # create the recorder and log file
        self.net = build_net(self.cfg)  # build the model
        self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids = range(self.cfg.gpus)).cuda()  # multi-GPU
        self.recorder.logger.info('Network: \n' + str(self.net))  # write the network structure to the log
        self.resume()  # restore weights if requested
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                     finetune_from=self.cfg.finetune_from, logger=self.recorder.logger)

    def to_cuda(self, batch):  # move the batch to CUDA
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch

    def train_epoch(self, epoch, train_loader):
        self.net.train()  # training mode (affects BN and dropout)
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:  # 80000
                break
            date_time = time.time() - end
            self.recorder.step += 1
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)  # forward pass
            self.optimizer.zero_grad()  # zero the gradients
            loss = output['loss']
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)  # time per batch
            self.recorder.data_time.update(date_time)  # data-loading time
            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)
        for epoch in range(self.cfg.epochs):
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        for i, data in enumerate(tqdm(val_loader, desc='Validate')):
            data = self.to_cuda(data)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(val_loader.dataset, output, data)
        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:
            self.metric = metric
            self.save_ckpt(is_best=True)
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                   self.recorder, is_best)
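The save and validation cadence in `Runner.train` follows from two modulo checks plus a "last epoch" condition. A minimal sketch that reproduces only this bookkeeping; the function name `schedule` and the sample values are hypothetical, and the early stop on `total_iter` is ignored:

```python
def schedule(epochs, save_ep, eval_ep):
    """Return the 0-based epochs at which train() saves and validates."""
    saves, evals = [], []
    for epoch in range(epochs):
        last = epoch == epochs - 1
        if (epoch + 1) % save_ep == 0 or last:
            saves.append(epoch)
        if (epoch + 1) % eval_ep == 0 or last:
            evals.append(epoch)
    return saves, evals


# Made-up settings: 10 epochs, save every 4, validate every 3.
saves, evals = schedule(epochs=10, save_ep=4, eval_ep=3)
print("save after epochs:", saves)
print("validate after epochs:", evals)
```

Validation additionally saves a `best.pth` checkpoint whenever the metric returned by the evaluator improves, independent of `save_ep`.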
main.py
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader

def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)
    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view
    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type
    cudnn.benchmark = True
    cudnn.fastest = True
    runner = Runner(cfg)
    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')  # config
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',  # work dir
        help='work dirs')
    parser.add_argument(
        '--load_from', default=None,
        help='the checkpoint file to resume from')  # resume
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')  # finetune
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    # A list default keeps args.gpus iterable when --gpus is omitted
    # (the string default '0' is converted by argparse to the int 0).
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    main()
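How the `--gpus` flag turns into `CUDA_VISIBLE_DEVICES` and `cfg.gpus` can be checked in isolation. This sketch uses a list default (`[0]`) so that the result stays iterable when the flag is omitted; `parse_gpus` and `visible_devices` are hypothetical helper names, not functions from `main.py`.

```python
import argparse


def parse_gpus(argv):
    """Parse only the --gpus option from an explicit argument list."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    return parser.parse_args(argv).gpus


def visible_devices(gpus):
    """Build the value assigned to CUDA_VISIBLE_DEVICES in main()."""
    return ','.join(str(gpu) for gpu in gpus)


two = parse_gpus(['--gpus', '0', '1'])
none_given = parse_gpus([])

print(visible_devices(two), len(two))                # device string and cfg.gpus
print(visible_devices(none_given), len(none_given))
```

`nargs='+'` collects one or more integers into a list, so `--gpus 0 1` selects two devices and `len(args.gpus)` becomes the number of GPUs handed to `DataParallel`.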
Reproducing the code on Colab
!wget https://s3.us-east-2.amazonaws.com/benchmark-frontend/datasets/1/test_set.zip
!7za x test_set.zip -o/content/drive/MyDrive/chx
!wget https://s3.us-east-2.amazonaws.com/benchmark-frontend/datasets/1/train_set.zip
!7za x train_set.zip -o/content/drive/MyDrive/chx
! /opt/bin/nvidia-smi
cd /content
!git clone https://github.com/ZJULearning/resa.git
cd ./resa
!pip install -r requirement.txt
!pip install pytorch_warmup
!pip install yapf
# In configs/tusimple.py, point the paths at the extracted dataset:
dataset_path = '/content/drive/MyDrive/chx'
test_json_file = '/content/drive/MyDrive/chx/test_label.json'
!python main.py configs/tusimple.py --gpus 0 --work_dirs /content/drive/MyDrive/chxsave