
Summary of Our National Second Prize in the 10th China Software Cup (Track A4: Intelligent Identification of Forestry Pests)

After several months of competition and the final defense, the Software Cup has wrapped up, and we ended with a result we're fairly satisfied with.

Our team was called "P5 天下第一" ("Persona 5 Is the Best"). Looking back, the name leaves a strange aftertaste (thanks a lot, Sony): my copy of Persona 5 is still gathering dust in my PS4, and then Persona 5 Royal went straight to Xbox Game Pass. "Only On PlayStation"... for a while, apparently.

Back to the topic. Let's start with the database and backend design.

Below is the overall architecture diagram.


Pest-Identification-Provider

The database is quite simple: a user table, tables holding the pests' taxonomic information (order, family, genus, species), and the association tables linking them.
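The schema idea can be sketched in a few SQL statements. This is only an illustration of the shape described above (one table per taxonomic rank plus association tables); the table and column names are hypothetical, not the project's actual DDL, and SQLite stands in for MySQL here:

```python
import sqlite3

# Hypothetical schema sketch: a user table, one table per taxonomic rank,
# and association tables between adjacent ranks. "order" is a reserved
# word in SQL, hence the t_order name.
DDL = """
CREATE TABLE user          (id INTEGER PRIMARY KEY, username TEXT UNIQUE, password TEXT);
CREATE TABLE t_order       (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE family        (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE genus         (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE species       (id INTEGER PRIMARY KEY, name TEXT, description TEXT);
CREATE TABLE order_family  (order_id INTEGER, family_id INTEGER);
CREATE TABLE family_genus  (family_id INTEGER, genus_id INTEGER);
CREATE TABLE genus_species (genus_id INTEGER, species_id INTEGER);
"""

conn = sqlite3.connect(":memory:")
conn.executescript(DDL)
tables = [r[0] for r in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")]
```

Walking species → genus → family → order through the association tables is then a chain of joins.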


The backend uses Spring Cloud with Nacos for service discovery. It is split into a provider and a consumer, with Spring Cloud Gateway in front handling load balancing.


The YoloV4 image-recognition module is also load-balanced, via gunicorn.
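A gunicorn setup for a Flask service like this might look like the sketch below. This is a hypothetical gunicorn.conf.py, not our actual deployment config; the worker count follows the commonly recommended 2×CPU+1 rule of thumb:

```python
# gunicorn.conf.py -- hypothetical config for the Flask recognition service.
# gunicorn pre-forks this many worker processes and balances requests
# across them, which is the "load balancing" referred to above.
import multiprocessing

bind = "0.0.0.0:5000"                          # address the Flask app listens on
workers = multiprocessing.cpu_count() * 2 + 1  # common rule of thumb for worker count
timeout = 120                                  # model inference can be slow; allow long requests
```

Started with something like `gunicorn -c gunicorn.conf.py app:app`.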

The admin backend uses JWT tokens plus Spring Security for authentication and authorization.

package com.rjgc.configs;

/**
 * @author zhaoyunjie
 * @date 2021-04-09 18:39
 */
@EnableWebSecurity
@Configuration
public class SpringSecurityConfig extends WebSecurityConfigurerAdapter {

    @Autowired
    private CustomUserDetainsService userDetainsService;

    @Autowired
    private CsrfProperties csrfProperties;

    @Autowired
    private CustomCorsFilter customCorsFilter;

    @Autowired
    private JwtTokenUtils jwtTokenUtils;

    @Autowired
    @Qualifier("userServiceImpl")
    private UserService userService;

    @Override
    protected void configure(HttpSecurity http) throws Exception {
        http.cors().and()
                // add CORS response headers
                .addFilterAfter(customCorsFilter, HeaderWriterFilter.class)
                // token login filter
                .addFilterBefore(new TokenLoginFilter(authenticationManager(), jwtTokenUtils), UsernamePasswordAuthenticationFilter.class)
                // token authentication filter
                .addFilterBefore(new TokenAuthenticationFilter(authenticationManager(), jwtTokenUtils, userService), BasicAuthenticationFilter.class).httpBasic()
                .and()
                // stateless: no HTTP session
                .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS)
                .and()
                .authorizeRequests()
                .antMatchers(HttpMethod.OPTIONS).permitAll()
                .antMatchers("/login").permitAll()
                .antMatchers(HttpMethod.GET, "/species/*").permitAll()
                .antMatchers("/imgFake").permitAll()
                .anyRequest().hasRole("admin");
        if (csrfProperties.getCsrfDisabled()) {
            http.csrf().disable();
        }
    }

    @Override
    protected void configure(AuthenticationManagerBuilder auth) throws Exception {
        auth.userDetailsService(userDetainsService).passwordEncoder(passwordEncoder());
    }

    @Override
    public void configure(WebSecurity web) throws Exception {
        web.ignoring()
                .antMatchers("/v2/api-docs",
                        "/swagger-resources/configuration/ui",
                        "/swagger-resources",
                        "/swagger-resources/configuration/security",
                        "/swagger-ui.html",
                        "/webjars/**");
    }

    @Bean
    public BCryptPasswordEncoder passwordEncoder() {
        return new BCryptPasswordEncoder();
    }
}

Two filters are added to the Spring Security chain: a token login filter and a token authentication filter.

package com.rjgc.filters;

/**
 * @author zhaoyunjie
 * @date 2021-04-15 19:56
 */
public class TokenLoginFilter extends UsernamePasswordAuthenticationFilter {
    private final AuthenticationManager authenticationManager;
    private final JwtTokenUtils jwtTokenUtils;

    public TokenLoginFilter(AuthenticationManager authenticationManager, JwtTokenUtils jwtTokenUtils) {
        this.authenticationManager = authenticationManager;
        this.jwtTokenUtils = jwtTokenUtils;
        this.setPostOnly(false);
        this.setRequiresAuthenticationRequestMatcher(new AntPathRequestMatcher("/login", "POST"));
    }

    @Override
    public Authentication attemptAuthentication(HttpServletRequest req, HttpServletResponse res) throws AuthenticationException {
        Map<String, String[]> map = req.getParameterMap();
        String username = map.get("username")[0];
        String password = map.get("password")[0];
        User user = new User(username, password, new ArrayList<>());
        return authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(user.getUsername(), user.getPassword(), new ArrayList<>()));
    }

    /**
     * Called on successful login
     */
    @Override
    protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication auth) {
        User user = (User) auth.getPrincipal();
        String token = jwtTokenUtils.generateToken(user.getUsername());
        HashMap<String, Object> map = new HashMap<>();
        map.put("token", token);
        map.put("username", user.getUsername());
        ResponseUtils.out(response, ResBody.success(map));
    }

    /**
     * Called on failed login
     */
    @Override
    protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) {
        ResponseUtils.out(response, ResBody.error(new BizException(ExceptionsEnum.LOGIN_FAILED)));
    }
}

TokenLoginFilter holds a JwtTokenUtils. On a login attempt, attemptAuthentication is called first; once the username and password check out, successfulAuthentication generates a token from the username and returns it to the frontend.
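Under the hood a JWT is just a signed JSON payload. A minimal, stdlib-only sketch of HS256 token creation and verification (illustrative only, not the project's JwtTokenUtils; the secret key is a placeholder):

```python
import base64
import hashlib
import hmac
import json
from typing import Optional

SECRET = b"demo-secret"  # hypothetical signing key

def _b64(data: bytes) -> str:
    # JWT uses unpadded URL-safe base64
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def generate_token(username: str) -> str:
    """Build header.payload.signature, signing with HMAC-SHA256."""
    header = _b64(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64(json.dumps({"sub": username}).encode())
    signing_input = f"{header}.{payload}".encode()
    sig = _b64(hmac.new(SECRET, signing_input, hashlib.sha256).digest())
    return f"{header}.{payload}.{sig}"

def verify_token(token: str) -> Optional[str]:
    """Return the username if the signature checks out, else None."""
    header, payload, sig = token.split(".")
    signing_input = f"{header}.{payload}".encode()
    expected = _b64(hmac.new(SECRET, signing_input, hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        return None  # tampered or wrongly signed token
    padded = payload + "=" * (-len(payload) % 4)   # restore base64 padding
    return json.loads(base64.urlsafe_b64decode(padded))["sub"]
```

The token authentication filter then does the verify step on every request, so the server never needs to keep session state.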

Passwords are hashed with the default BCryptPasswordEncoder.

@Bean
public BCryptPasswordEncoder passwordEncoder() {
    return new BCryptPasswordEncoder();
}

The persistence layer uses MySQL with MyBatis-Plus.

The project follows a standard MVC layout: the endpoints are defined in the controllers, and the business logic lives in the services.

package com.rjgc.controller;

/**
 * @author zhaoyunjie
 * @date 2021-04-08 13:18
 */
@RestController
@RequestMapping("/family")
public class FamilyController {

    @Qualifier("familyServiceImpl")
    @Autowired
    private FamilyService familyService;

    @Autowired
    @Qualifier("familyGenusVoServiceImpl")
    private FamilyGenusVoService familyGenusVoService;

    ......

    /**
     * Query a family by id
     *
     * @param id id
     * @return result list
     */
    @GetMapping
    @ApiOperation("Query family by id")
    public ResBody<Map<String, Object>> selectFamiliesById(@RequestParam int id) {
        return ResBody.success(familyVoService.selectFamiliesById(id));
    }

    /**
     * Query all families
     *
     * @param pageNum  starting page of the query
     * @param pageSize page size
     * @return result list
     */
    @GetMapping("all")
    @ApiOperation("Query all families")
    public ResBody<Map<String, Object>> selectAllFamilies(@RequestParam int pageNum, @RequestParam int pageSize) {
        return ResBody.success(familyVoService.selectAllFamilies(pageNum, pageSize));
    }

    @GetMapping("name")
    @ApiOperation("Query families by name")
    public ResBody<Map<String, Object>> selectFamiliesByName(@RequestParam int pageNum, @RequestParam int pageSize, @RequestParam String name) {
        return ResBody.success(familyVoService.selectFamiliesByName(pageNum, pageSize, name));
    }

    ......

A unified response wrapper, ResBody, is defined here. It makes handling responses on the frontend easier and avoids having to deal with null values.

package com.rjgc.exceptions;

import lombok.Data;

import java.util.HashMap;
import java.util.Map;

/**
 * @author zhaoyunjie
 * @date 2021-04-09 10:42
 */
@Data
public class ResBody<T> {
    /**
     * Response code
     */
    private int code;

    /**
     * Response message
     */
    private String message;

    /**
     * Response payload
     */
    private Object result;

    private static Map<String, Object> RESULT_HOLDER = new HashMap<>();

    static {
        RESULT_HOLDER.put("pages", 1);
        RESULT_HOLDER.put("data", "");
    }

    /**
     * Success
     *
     * @return resBody
     */
    public static <T> ResBody<T> success() {
        return success(null);
    }

    /**
     * Success
     *
     * @param data data
     * @return resBody
     */
    public static <T> ResBody<T> success(T data) {
        ResBody<T> rb = new ResBody<>();
        rb.setCode(ExceptionsEnum.SUCCESS.getResCode());
        rb.setMessage(ExceptionsEnum.SUCCESS.getResMsg());
        rb.setResult(data);
        return rb;
    }

    /**
     * Failure
     */
    public static <T> ResBody<T> error(BizException errorInfo) {
        ResBody<T> rb = new ResBody<>();
        rb.setCode(errorInfo.getExp().getResCode());
        rb.setMessage(errorInfo.getExp().getResMsg());
        rb.setResult(RESULT_HOLDER);
        return rb;
    }

    public static <T> ResBody<T> error(int errorCode, String errorInfo) {
        ResBody<T> rb = new ResBody<>();
        rb.setCode(errorCode);
        rb.setMessage(errorInfo);
        rb.setResult(RESULT_HOLDER);
        return rb;
    }

    @Override
    public String toString() {
        return "ResBody{" +
                "code=" + code +
                ", message='" + message + '\'' +
                ", result=" + result +
                '}';
    }
}

The project also uses a global exception handler to map the different exception types to responses.

package com.rjgc.exceptions;

import lombok.extern.slf4j.Slf4j;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import javax.servlet.http.HttpServletRequest;

/**
 * @author zhaoyunjie
 * @date 2021-04-09 10:38
 * <p>
 * Global exception handling
 */
@RestControllerAdvice
@Slf4j
public class GlobalExceptionHandler {

    @ExceptionHandler(value = BizException.class)
    public ResBody<BizException> bizExceptionHandler(HttpServletRequest req, BizException e) {
        log.debug("Message: " + e.getMessage());
        log.debug("Cause: " + e.getCause());
        return ResBody.error(e);
    }

    @ExceptionHandler(value = NullPointerException.class)
    public ResBody<NullPointerException> nullPointerExceptionHandler(HttpServletRequest req, NullPointerException e) {
        e.printStackTrace();
        log.debug("Message: " + e.getMessage());
        log.debug("Cause: " + e.getCause());
        return ResBody.error(50001, "Null pointer exception");
    }

    @ExceptionHandler(value = DuplicateKeyException.class)
    public ResBody<DuplicateKeyException> duplicateKeyExceptionHandler(HttpServletRequest req, DuplicateKeyException e) {
        log.debug("Message: " + e.getMessage());
        log.debug("Cause: " + e.getCause());
        return ResBody.error(new BizException(ExceptionsEnum.INVALID_ID));
    }

    @ExceptionHandler(value = Exception.class)
    public ResBody<Exception> exceptionHandler(HttpServletRequest req, Exception e) {
        e.printStackTrace();
        log.error("Message: " + e.getMessage());
        log.error("Cause: " + e.getCause());
        return ResBody.error(new BizException(ExceptionsEnum.DATABASE_FAILED));
    }
}

That covers the provider's design.

Pest-Identification-Consumer

The consumer is essentially a further wrapper around the provider, making remote calls through OpenFeign.

@FeignClient(value = "pest-identification-provider", configuration = FeignConfig.class)
public interface FamilyClient {

    @GetMapping("/family")
    ResBody<Map<String, Object>> selectFamiliesById(@RequestParam int id);

    @GetMapping("/family/all")
    ResBody<Map<String, Object>> selectAllFamilies(@RequestParam int pageNum, @RequestParam int pageSize);

    @GetMapping("/family/name")
    ResBody<Map<String, Object>> selectFamiliesByName(@RequestParam int pageNum, @RequestParam int pageSize, @RequestParam String name);

    @PostMapping("/family/{orderId}")
    ResBody<Integer> insertFamily(@PathVariable int orderId, @RequestBody Family family);

    @PutMapping("/family")
    ResBody<Integer> updateFamily(@RequestParam int orderId, @RequestBody Family newFamily);

    @DeleteMapping("/family")
    ResBody<Integer> deleteFamilyById(@RequestParam int id);
}

Pest-Identification-Gateway

For the gateway, Spring already provides a convenient solution: configure the load-balanced routes in the config file and register with Nacos, and it just works.

server:
  port: 8081

spring:
  application:
    name: pest-identification-gateway

  cloud:
    nacos:
      discovery:
        server-addr: localhost:8848
      config:
        enabled: false

    gateway:
      discovery:
        locator:
          enabled: true
          lower-case-service-id: true
      routes:
        - id: consumer
          uri: lb://pest-identification-consumer
          predicates:
            - Path=/**
      httpclient:
        connect-timeout: 3000
        response-timeout: 10000

The YoloV4 Network

One of the hardest problems in this competition was image recognition. I used YoloV4 for object detection. Since the competition also required offline recognition, the Android client additionally uses YoloV5 so it can keep identifying pests without a network connection.

First, YoloV4. For the competition I exposed the recognition API through Flask.

@app.route('/startPredict', methods=['POST'])
def start_predict():
    """
    Run prediction on the pictures under img/${ip} and return the results
    :return:
    """
    ret = {}
    each_statistics: dict
    if not pictures:
        return resp_utils.error('No file has been uploaded yet')
    for each in pictures:
        info = load_index(each)
        if not info:
            each_statistics, base64_str = predict.predict_img(each)
            ret[os.path.basename(each)] = {'statistics': each_statistics, 'img': base64_str}
            if 'error' not in each_statistics:
                save_index(each, base64_str, each_statistics)
        else:
            each_statistics = info['statistics']
            base64_str = info['img']
            ret[os.path.basename(each)] = {'statistics': each_statistics, 'img': base64_str}
    pictures.clear()
    return resp_utils.success(ret)

YoloV4 itself is implemented in PyTorch. Here is a diagram of the YoloV4 architecture, taken from 睿智的目标检测32——TF2搭建YoloV4目标检测平台(tensorflow2)_Bubbliiiing的博客-CSDN博客.

I learned a great deal from Bubbliiiing; he also has YoloV4 tutorials on Bilibili, and I found his videos enormously helpful. Bubbliiiing的个人空间_哔哩哔哩_bilibili


A detailed walkthrough can be found here:

YOLOV4 pytorch实现流程 - 知乎 (zhihu.com)

First comes CSPdarknet: define the Mish activation function, implement the basic convolution block, and then assemble the full CSPdarknet.

PyTorch is wonderfully convenient here: subclass nn.Module, write forward, and you're done. I spent ages fighting TensorFlow before this; different backends and different versions weren't even compatible with each other. Far too many pitfalls.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# Mish activation function
class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


# Conv2d + BatchNormalization + Mish
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(BasicConv, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = Mish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x


# Basic residual block
class Resblock(nn.Module):
    def __init__(self, channels, hidden_channels=None):
        super(Resblock, self).__init__()

        if hidden_channels is None:
            hidden_channels = channels

        self.block = nn.Sequential(
            BasicConv(channels, hidden_channels, 1),
            BasicConv(hidden_channels, channels, 3)
        )

    def forward(self, x):
        return x + self.block(x)


# Structural block of CSPdarknet
class Resblock_body(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, first):
        super(Resblock_body, self).__init__()
        # ----------------------------------------------------------------#
        #   Use a stride-2 convolution to halve the height and width
        # ----------------------------------------------------------------#
        self.downsample_conv = BasicConv(in_channels, out_channels, 3, stride=2)

        if first:
            # --------------------------------------------------------------------------#
            #   Build a large shortcut edge, self.split_conv0, which bypasses the residual stack
            # --------------------------------------------------------------------------#
            self.split_conv0 = BasicConv(out_channels, out_channels, 1)

            # ----------------------------------------------------------------#
            #   The main path loops over num_blocks residual blocks
            # ----------------------------------------------------------------#
            self.split_conv1 = BasicConv(out_channels, out_channels, 1)
            self.blocks_conv = nn.Sequential(
                Resblock(channels=out_channels, hidden_channels=out_channels // 2),
                BasicConv(out_channels, out_channels, 1)
            )

            self.concat_conv = BasicConv(out_channels * 2, out_channels, 1)
        else:
            # --------------------------------------------------------------------------#
            #   Build a large shortcut edge, self.split_conv0, which bypasses the residual stack
            # --------------------------------------------------------------------------#
            self.split_conv0 = BasicConv(out_channels, out_channels // 2, 1)

            # ----------------------------------------------------------------#
            #   The main path loops over num_blocks residual blocks
            # ----------------------------------------------------------------#
            self.split_conv1 = BasicConv(out_channels, out_channels // 2, 1)
            self.blocks_conv = nn.Sequential(
                *[Resblock(out_channels // 2) for _ in range(num_blocks)],
                BasicConv(out_channels // 2, out_channels // 2, 1)
            )

            self.concat_conv = BasicConv(out_channels, out_channels, 1)

    def forward(self, x):
        x = self.downsample_conv(x)

        x0 = self.split_conv0(x)

        x1 = self.split_conv1(x)
        x1 = self.blocks_conv(x1)

        # ------------------------------------#
        #   Concatenate the large shortcut edge back in
        # ------------------------------------#
        x = torch.cat([x1, x0], dim=1)
        # ------------------------------------#
        #   Finally merge the channel count
        # ------------------------------------#
        x = self.concat_conv(x)

        return x


# CSPdarknet53
class CSPDarkNet(nn.Module):
    def __init__(self, layers):
        super(CSPDarkNet, self).__init__()
        self.inplanes = 32
        # 416,416,3 -> 416,416,32
        self.conv1 = BasicConv(3, self.inplanes, kernel_size=3, stride=1)
        self.feature_channels = [64, 128, 256, 512, 1024]

        self.stages = nn.ModuleList([
            # 416,416,32 -> 208,208,64
            Resblock_body(self.inplanes, self.feature_channels[0], layers[0], first=True),
            # 208,208,64 -> 104,104,128
            Resblock_body(self.feature_channels[0], self.feature_channels[1], layers[1], first=False),
            # 104,104,128 -> 52,52,256
            Resblock_body(self.feature_channels[1], self.feature_channels[2], layers[2], first=False),
            # 52,52,256 -> 26,26,512
            Resblock_body(self.feature_channels[2], self.feature_channels[3], layers[3], first=False),
            # 26,26,512 -> 13,13,1024
            Resblock_body(self.feature_channels[3], self.feature_channels[4], layers[4], first=False)
        ])

        self.num_features = 1
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)

        x = self.stages[0](x)
        x = self.stages[1](x)
        out3 = self.stages[2](x)
        out4 = self.stages[3](out3)
        out5 = self.stages[4](out4)

        return out3, out4, out5


def darknet53(pretrained):
    model = CSPDarkNet([1, 2, 8, 8, 4])
    if pretrained:
        if isinstance(pretrained, str):
            model.load_state_dict(torch.load(pretrained))
        else:
            raise Exception("darknet request a pretrained path. got [{}]".format(pretrained))
    return model

Next comes the YoloV4 body itself, including the SPP and PANet modules from the diagram above, which finally outputs the three Yolo head feature layers.

from collections import OrderedDict

import torch
import torch.nn as nn

from nets.CSPdarknet import darknet53


def conv2d(filter_in, filter_out, kernel_size, stride=1):
    pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.LeakyReLU(0.1)),
    ]))


# SPP structure
class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes=[5, 9, 13]):
        super(SpatialPyramidPooling, self).__init__()

        self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size // 2) for pool_size in pool_sizes])

    def forward(self, x):
        features = [maxpool(x) for maxpool in self.maxpools[::-1]]
        features = torch.cat(features + [x], dim=1)

        return features


# Convolution + upsampling
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()

        self.upsample = nn.Sequential(
            conv2d(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        x = self.upsample(x)
        return x


# Three-conv block
def make_three_conv(filters_list, in_filters):
    m = nn.Sequential(
        conv2d(in_filters, filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        conv2d(filters_list[1], filters_list[0], 1),
    )
    return m


# Five-conv block
def make_five_conv(filters_list, in_filters):
    m = nn.Sequential(
        conv2d(in_filters, filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        conv2d(filters_list[1], filters_list[0], 1),
        conv2d(filters_list[0], filters_list[1], 3),
        conv2d(filters_list[1], filters_list[0], 1),
    )
    return m


# YoloV4 output head
def yolo_head(filters_list, in_filters):
    m = nn.Sequential(
        conv2d(in_filters, filters_list[0], 3),
        nn.Conv2d(filters_list[0], filters_list[1], 1),
    )
    return m


# yolo_body
class YoloBody(nn.Module):
    def __init__(self, num_anchors, num_classes):
        super(YoloBody, self).__init__()
        # ---------------------------------------------------#
        #   Build the CSPdarknet53 backbone, which yields
        #   three effective feature layers with shapes:
        #   52,52,256
        #   26,26,512
        #   13,13,1024
        # ---------------------------------------------------#
        self.backbone = darknet53(None)

        self.conv1 = make_three_conv([512, 1024], 1024)
        self.SPP = SpatialPyramidPooling()
        self.conv2 = make_three_conv([512, 1024], 2048)

        self.upsample1 = Upsample(512, 256)
        self.conv_for_P4 = conv2d(512, 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)

        self.upsample2 = Upsample(256, 128)
        self.conv_for_P3 = conv2d(256, 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)

        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
        final_out_filter2 = num_anchors * (5 + num_classes)
        self.yolo_head3 = yolo_head([256, final_out_filter2], 128)

        self.down_sample1 = conv2d(128, 256, 3, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)

        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
        final_out_filter1 = num_anchors * (5 + num_classes)
        self.yolo_head2 = yolo_head([512, final_out_filter1], 256)

        self.down_sample2 = conv2d(256, 512, 3, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)

        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
        final_out_filter0 = num_anchors * (5 + num_classes)
        self.yolo_head1 = yolo_head([1024, final_out_filter0], 512)

    def forward(self, x):
        # backbone
        x2, x1, x0 = self.backbone(x)

        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,2048
        P5 = self.conv1(x0)
        P5 = self.SPP(P5)
        # 13,13,2048 -> 13,13,512 -> 13,13,1024 -> 13,13,512
        P5 = self.conv2(P5)

        # 13,13,512 -> 13,13,256 -> 26,26,256
        P5_upsample = self.upsample1(P5)
        # 26,26,512 -> 26,26,256
        P4 = self.conv_for_P4(x1)
        # 26,26,256 + 26,26,256 -> 26,26,512
        P4 = torch.cat([P4, P5_upsample], axis=1)
        # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
        P4 = self.make_five_conv1(P4)

        # 26,26,256 -> 26,26,128 -> 52,52,128
        P4_upsample = self.upsample2(P4)
        # 52,52,256 -> 52,52,128
        P3 = self.conv_for_P3(x2)
        # 52,52,128 + 52,52,128 -> 52,52,256
        P3 = torch.cat([P3, P4_upsample], axis=1)
        # 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128
        P3 = self.make_five_conv2(P3)

        # 52,52,128 -> 26,26,256
        P3_downsample = self.down_sample1(P3)
        # 26,26,256 + 26,26,256 -> 26,26,512
        P4 = torch.cat([P3_downsample, P4], axis=1)
        # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
        P4 = self.make_five_conv3(P4)

        # 26,26,256 -> 13,13,512
        P4_downsample = self.down_sample2(P4)
        # 13,13,512 + 13,13,512 -> 13,13,1024
        P5 = torch.cat([P4_downsample, P5], axis=1)
        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512
        P5 = self.make_five_conv4(P5)

        # ---------------------------------------------------#
        #   Third feature layer
        #   y3 = (batch_size, 75, 52, 52)
        # ---------------------------------------------------#
        out2 = self.yolo_head3(P3)
        # ---------------------------------------------------#
        #   Second feature layer
        #   y2 = (batch_size, 75, 26, 26)
        # ---------------------------------------------------#
        out1 = self.yolo_head2(P4)
        # ---------------------------------------------------#
        #   First feature layer
        #   y1 = (batch_size, 75, 13, 13)
        # ---------------------------------------------------#
        out0 = self.yolo_head1(P5)

        return out0, out1, out2

With the network built, the next step is training.

Training ran on the lab's TITAN RTX. Our team collected nearly 1,900 images and labeled them all by hand; after a week of that we could hardly stand the sight of another bug.


Finally, a script scans all the images and pairs each one with its corresponding XML annotation file.
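The core of that pairing step can be sketched as follows (illustrative only: it assumes a VOC-style layout where each image and its XML annotation share a basename, which is not necessarily our exact script):

```python
import os

def pair_annotations(image_names, xml_names):
    """Match each image to the XML annotation that shares its basename."""
    xml_by_stem = {os.path.splitext(x)[0]: x for x in xml_names}
    pairs, unlabeled = [], []
    for img in image_names:
        stem = os.path.splitext(img)[0]
        if stem in xml_by_stem:
            pairs.append((img, xml_by_stem[stem]))
        else:
            unlabeled.append(img)  # flag images that still lack a label file
    return pairs, unlabeled

# Hypothetical file names standing in for os.listdir() results
pairs, unlabeled = pair_annotations(
    ["bug_001.jpg", "bug_002.jpg", "bug_003.jpg"],
    ["bug_001.xml", "bug_003.xml"])
```

The `unlabeled` list is handy as a sanity check that every image actually got annotated before training starts.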


Then training begins.

def start_train(queue=None):
    global model, __iterationCount
    __iterationCount = 1

    # Paths to the classes and anchors files
    anchors_path = 'model_data/yolo_anchors.txt'
    classes_path = 'model_data/all_classes.txt'
    # Load classes and anchors
    class_names = get_classes(classes_path)
    anchors = get_anchors(anchors_path)
    num_classes = len(class_names)

    # Build the yolo model
    model = YoloBody(len(anchors[0]), num_classes)

    model_path = "model_data/yolo4_weights.pth"
    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if os.path.exists(model_path):
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print('Finished!')
    else:
        print("error no pretrained model")

    net = model.train()

    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    # Build the loss functions
    yolo_losses = []
    for i in range(3):
        yolo_losses.append(YOLOLoss(np.reshape(anchors, [-1, 2]), num_classes,
                                    (input_shape[1], input_shape[0]), smooth_label, Cuda, normalize))

    # Load image paths and labels
    annotation_path = '2007_train.txt'
    val_split = 0.1
    with open(annotation_path, encoding='utf-8') as f:
        lines = f.readlines()
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_val = int(len(lines) * val_split)
    num_train = len(lines) - num_val

    Init_Epoch = 0
    Freeze_Epoch = 200
    Unfreeze_Epoch = 300

    # Frozen-backbone training phase
    lr = 1e-3
    Batch_size = 4

    global optimizer
    optimizer = optim.Adam(net.parameters(), lr)
    if Cosine_lr:
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
    else:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

    if Use_Data_Loader:
        train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
                                    is_train=True)
        val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
        gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                         drop_last=True, collate_fn=yolo_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                             drop_last=True, collate_fn=yolo_dataset_collate)
    else:
        gen = Generator(Batch_size, lines[:num_train],
                        (input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
        gen_val = Generator(Batch_size, lines[num_train:],
                            (input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)

    epoch_size = max(1, num_train // Batch_size)
    epoch_size_val = num_val // Batch_size
    # Freeze the backbone
    for param in model.backbone.parameters():
        param.requires_grad = False

    for epoch in range(Init_Epoch, Freeze_Epoch):
        fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Freeze_Epoch, Cuda,
                      queue=queue)
        lr_scheduler.step()

    # Unfrozen training phase
    lr = 1e-4
    Batch_size = 2

    optimizer = optim.Adam(net.parameters(), lr)
    if Cosine_lr:
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
    else:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

    if Use_Data_Loader:
        train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
                                    is_train=True)
        val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
        gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                         drop_last=True, collate_fn=yolo_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                             drop_last=True, collate_fn=yolo_dataset_collate)
    else:
        gen = Generator(Batch_size, lines[:num_train],
                        (input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
        gen_val = Generator(Batch_size, lines[num_train:],
                            (input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)

    epoch_size = max(1, num_train // Batch_size)
    epoch_size_val = num_val // Batch_size
    # Unfreeze the backbone
    for param in model.backbone.parameters():
        param.requires_grad = True

    for epoch in range(Freeze_Epoch, Unfreeze_Epoch):
        fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Unfreeze_Epoch, Cuda,
                      queue=queue)
        lr_scheduler.step()

The final training results turned out quite well.


The YoloV5 Network

As mentioned earlier, YoloV4 is too large to ship inside the Android app, so our team also built a YoloV5 version.


This model is likewise implemented in PyTorch, so I won't repeat the details here.

The training process is basically the same as above, using the same training set; once training finishes we simply export the model.

What's worth writing down is how we got it into the Android app.

Android APP

The Android client's core UI is built with vue-cli as an H5 mini-app. Since PyTorch on Android has to be pulled in through Android Studio, we couldn't simply package the Vue app into an APK directly.

So our team used CapacitorJS to convert the Vue app into a native project, loaded the offline model file with the native PyTorch library, used NanoHTTPD to emulate the online API locally, and synced with the remote database periodically via SQLite and OkHttpClient.

After a photo is taken on the Android client, it is first base64-encoded and uploaded to the YoloV4 service, and the result is displayed once it comes back.
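The base64 round trip itself is trivial. In Python terms (the app does this step in JavaScript/Java; this sketch just illustrates what travels over the wire):

```python
import base64

def encode_image(raw: bytes) -> str:
    """Encode raw image bytes as a base64 string for the upload payload."""
    return base64.b64encode(raw).decode("ascii")

def decode_image(b64: str) -> bytes:
    """What the server does before feeding the image to the network."""
    return base64.b64decode(b64)

photo = b"\xff\xd8\xff\xe0fake-jpeg-bytes"  # stand-in for real camera output
assert decode_image(encode_image(photo)) == photo
```

Base64 inflates the payload by about a third, but it keeps the image safely embeddable in a JSON request body.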

PyTorch is added in build.gradle:

dependencies {
    implementation fileTree(include: ['*.jar'], dir: 'libs')
    implementation "androidx.appcompat:appcompat:$androidxAppCompatVersion"
    implementation project(':capacitor-android')
    testImplementation "junit:junit:$junitVersion"
    androidTestImplementation "androidx.test.ext:junit:$androidxJunitVersion"
    androidTestImplementation "androidx.test.espresso:espresso-core:$androidxEspressoCoreVersion"
    implementation project(':capacitor-cordova-android-plugins')
    implementation 'org.nanohttpd:nanohttpd:2.3.1'
    implementation 'com.alibaba:fastjson:1.1.72.android'
    implementation 'com.squareup.okhttp3:okhttp:3.10.0'
    implementation 'org.pytorch:pytorch_android:1.8.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.8.0'
}

Then put the trained YoloV5 model into the assets folder and load it when the app starts.

/**
 * Load the network model
 * @param assetManager assets reader
 * @param moduleName TorchScript file name
 */
public void init(AssetManager assetManager, String moduleName) {
    module = PyTorchAndroid.loadModuleFromAsset(assetManager, moduleName);
}

Then implement the predict function:

/**
 * Takes an image and returns the prediction info
 * @param bitmap input image
 * @param processedImg processed image; an Object[1] is used so the result can be written back from inside the method
 * @return prediction info
 */
public Map<String, Integer> predict(Bitmap bitmap, Object[] processedImg) {
    HashMap<String, Integer> resultMap = new HashMap<>();
    // preparing input tensor
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
            new float[]{0.0f, 0.0f, 0.0f}, new float[]{1.0f, 1.0f, 1.0f});

    // running the model
    IValue value = IValue.from(inputTensor);
    final IValue[] output = module.forward(value).toTuple();
    final Tensor outputTensor = output[0].toTensor();
    final float[] outputs = outputTensor.getDataAsFloatArray();
    float mImgScaleX = bitmap.getWidth() / imageSize;
    float mImgScaleY = bitmap.getHeight() / imageSize;
    float mIvScaleX = (bitmap.getWidth() > bitmap.getHeight() ? imageSize / bitmap.getWidth() : imageSize / bitmap.getHeight());
    float mIvScaleY = (bitmap.getHeight() > bitmap.getWidth() ? imageSize / bitmap.getHeight() : imageSize / bitmap.getWidth());
    final ArrayList<Result> results = outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY);
    boolean hasRect = false;
    for (Result each : results) {
        String code = Classes.classes[each.classIndex];
        if (resultMap.containsKey(code)) {
            Integer num = resultMap.get(code);
            num += 1;
            resultMap.put(code, num);
        } else {
            resultMap.put(code, 1);
        }
        if (each.rect.top != each.rect.bottom && each.rect.left != each.rect.right) {
            bitmap = drawRectangles(bitmap, each.rect);
            processedImg[0] = bitmap;
            hasRect = true;
        }
        if (!hasRect) {
            processedImg[0] = bitmap;
        }
    }
    if (resultMap.isEmpty()) {
        resultMap.put("error", 1);
    }
    System.out.println(resultMap);
    return resultMap;
}

...and a function that draws bounding boxes on the recognized image:

/**
 * Draw a bounding box on the bitmap
 * @param imageBitmap input image
 * @param valueRects position of the box
 * @return the image with the box drawn on it
 */
private Bitmap drawRectangles(Bitmap imageBitmap, Rect valueRects) {
    Bitmap mutableBitmap = imageBitmap.copy(Bitmap.Config.ARGB_8888, true);
    Canvas canvas = new Canvas(mutableBitmap);
    Paint paint = new Paint();
    for (int i = 0; i < 8; i++) {
        paint.setColor(Color.RED);
        paint.setStyle(Paint.Style.STROKE); // stroke only, no fill
        paint.setStrokeWidth(10); // stroke width
        canvas.drawRect(valueRects.left, valueRects.top, valueRects.right, valueRects.bottom, paint);
    }
    return mutableBitmap;
}

...plus the post-processing: converting the raw outputs into predictions, non-maximum suppression (NMS), and the IoU computation.

/**
 * Convert the raw network outputs into predictions
 * @param outputs raw outputs
 * @param imgScaleX scale relative to the original image, X
 * @param imgScaleY scale relative to the original image, Y
 * @param ivScaleX scale of the displayed image, X
 * @param ivScaleY scale of the displayed image, Y
 * @return Result list
 */
private static ArrayList<Result> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY) {
    ArrayList<Result> results = new ArrayList<>();
    int mOutputRow = 25200;
    int mOutputColumn = Classes.classes.length + 5;
    for (int i = 0; i < mOutputRow; i++) {
        try {
            if (outputs[i * mOutputColumn + 4] > mThreshold) {
                float x = outputs[i * mOutputColumn];
                float y = outputs[i * mOutputColumn + 1];
                float w = outputs[i * mOutputColumn + 2];
                float h = outputs[i * mOutputColumn + 3];

                float left = imgScaleX * (x - w / 2);
                float top = imgScaleY * (y - h / 2);
                float right = imgScaleX * (x + w / 2);
                float bottom = imgScaleY * (y + h / 2);

                float max = outputs[i * mOutputColumn + 5];
                int cls = 0;
                for (int j = 0; j < mOutputColumn - 5; j++) {
                    if (outputs[i * mOutputColumn + 5 + j] > max) {
                        max = outputs[i * mOutputColumn + 5 + j];
                        cls = j;
                    }
                }

                Rect rect = new Rect((int) ((float) 0 + ivScaleX * left), (int) ((float) 0 + top * ivScaleY), (int) ((float) 0 + ivScaleX * right), (int) ((float) 0 + ivScaleY * bottom));
                Result result = new Result(cls, outputs[i * mOutputColumn + 4], rect);
                results.add(result);
            }
        } catch (ArrayIndexOutOfBoundsException e) {
            e.printStackTrace();
        }
    }
    return nonMaxSuppression(results);
}

/**
 * Non-maximum suppression
 * @param boxes candidate boxes
 * @return filtered boxes
 */
private static ArrayList<Result> nonMaxSuppression(ArrayList<Result> boxes) {

    Collections.sort(boxes,
            (o1, o2) -> o1.score.compareTo(o2.score));

    ArrayList<Result> selected = new ArrayList<>();
    boolean[] active = new boolean[boxes.size()];
    Arrays.fill(active, true);
    int numActive = active.length;

    boolean done = false;
    for (int i = 0; i < boxes.size() && !done; i++) {
        if (active[i]) {
            Result boxA = boxes.get(i);
            selected.add(boxA);
            if (selected.size() >= CnnApi.mNmsLimit) break;

            for (int j = i + 1; j < boxes.size(); j++) {
                if (active[j]) {
                    Result boxB = boxes.get(j);
                    if (IOU(boxA.rect, boxB.rect) > CnnApi.mThreshold) {
                        active[j] = false;
                        numActive -= 1;
                        if (numActive <= 0) {
                            done = true;
                            break;
                        }
                    }
                }
            }
        }
    }
    return selected;
}

private static float IOU(Rect a, Rect b) {
    float areaA = (a.right - a.left) * (a.bottom - a.top);
    if (areaA <= 0.0) return 0.0f;

    float areaB = (b.right - b.left) * (b.bottom - b.top);
    if (areaB <= 0.0) return 0.0f;

    float intersectionMinX = Math.max(a.left, b.left);
    float intersectionMinY = Math.max(a.top, b.top);
    float intersectionMaxX = Math.min(a.right, b.right);
    float intersectionMaxY = Math.min(a.bottom, b.bottom);
    float intersectionArea = Math.max(intersectionMaxY - intersectionMinY, 0) *
            Math.max(intersectionMaxX - intersectionMinX, 0);
    return intersectionArea / (areaA + areaB - intersectionArea);
}

static class Result {
    int classIndex;
    Float score;
    Rect rect;

    public Result(int cls, Float output, Rect rect) {
        this.classIndex = cls;
        this.score = output;
        this.rect = rect;
    }
}

This offline recognition routine is then wrapped up as an API.

NanoHTTPD exposes it over HTTP, so when there is no network the client just switches the server address to 127.0.0.1 and its requests hit the local YoloV5 network instead, giving us offline recognition.
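The idea (a tiny loopback HTTP server standing in for the remote API, so the client code stays unchanged) can be sketched with Python's stdlib, though the app itself does this with NanoHTTPD in Java. The endpoint path and response shape here are hypothetical:

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

class LocalPredictHandler(BaseHTTPRequestHandler):
    """Serves the same endpoint shape the remote API exposes (hypothetical)."""
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        _image_b64 = self.rfile.read(length)   # would be fed to the local model
        body = json.dumps({"statistics": {"demo_pest": 1}}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # keep the sketch quiet

server = HTTPServer(("127.0.0.1", 0), LocalPredictHandler)  # port 0: pick any free port
threading.Thread(target=server.serve_forever, daemon=True).start()

# The client simply points its base URL at the loopback address
url = f"http://127.0.0.1:{server.server_address[1]}/startPredict"
resp = urllib.request.urlopen(urllib.request.Request(url, data=b"fake-image", method="POST"))
result = json.loads(resp.read())
server.shutdown()
```

Because both the local and remote servers speak the same protocol, switching between online and offline mode is purely a matter of which base URL the client uses.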

Finally, a few pictures of the finished product.
