0%

为了优化Java的性能 ,JVM在解释器之外引入了即时(Just In Time)编译器:当程序运行时,解释器首先发挥作用,代码可以直接执行。当程序运行时,JIT编译器再编译代码,获取更高的运行效率。

这篇文章来记录下JIT编译器的工作过程以及一些优化的把戏。本文章中大部分的知识都是来自这个视频中的内容,我学习之后将他们记录下来。

理解 JVM 中 JIT 玩的把戏_1_哔哩哔哩_bilibili

JIT分层编译

Java 8引入了分层编译的概念,分层编译将JVM的执行状态分为了五个层次。五个层级分别是:

  1. 解释执行。

  2. 执行不带profiling的C1代码。

  3. 执行仅带方法调用次数以及循环回边执行次数profiling的C1代码。

  4. 执行带所有profiling的C1代码。

  5. 执行C2代码。

第一阶段就是解释执行,第二阶段是C1编译,C1拥有三个模式,也就是2,3,4,最后是C2编译。

C1第一个模式,仅仅是编译执行,代码中没有任何埋点进行profiling的地方。

C1第二个模式,是编译加上一个counter计数器,当计数器达到阈值将会到达C2。

C1第三个模式,是编译加上所有的profiling,进行详细分析。比如某个分支,永远是true,或者某一个接口方法的调用,只使用过一种类型(比如List的add方法,实际上永远是ArrayList),或者某个强制转换从未失败,某个调用从来没有产生过NEP异常。

在一般情况下,代码从解释执行到C1,会直接到C1的第三个模式,最后达到C2。

在C2忙的时候,也就是说C2拥有很多方法需要编译,队列中有很多待处理的任务时,代码会先达到C1的第二个模式,再达到C1的第三个模式,以此来减少C2的执行时间。在C1第三个模式下,如果profiling没有收集到有价值的数据,那么jvm会断定C2编译并没有比C1好很多。因此将会直接到C1的第一个模式。

在C1忙的时候也有可能直接达到C2,在解释过程中进行profiling,然后直接由C2编译。

Uncommon Trap

比如ArrayStream的forEach方法。

1
2
3
4
5
ArrayStream<E>.forEach(Consumer<? super E> action) {
for (int i = 0; i < this.elementData.length(), i++) {
action.accept(this.elementData[i]);
}
}

首先,jvm为了保证代码的安全性,必须在代码中添加安全检查,因此,代码变成了下面这样。

1
2
3
4
5
6
7
8
9
10
11
12
ArrayStream<E>.forEach(Consumer<? super E> action) {
if (this == null) throw new NPE();
if (this.elementData == null) throw new NPE();
for (int i = 0; i < this.elementData.length(), i++) {
if (this == null) throw new NPE();
if (this.elementData == null) throw new NPE();
if (i < 0) throw new AIOBE();
if (i >= this.elementData.length) throw new AIOBE();
if (action == null) throw new NPE();
action.accept(this.elementData[i]);
}
}

首先可以去掉的就是对this的检查,因为在对象的方法体中,this不可能为null。

其次,由于i从0开始,并且this.elementData.length()不可能大于Integer.MAX,且i递增每次+1,所以,i不可能小于0。因此 i < 0被去掉了。

接着,对于 i >= this.elementData.length,编译器将它优化为了 !(i < this.elementData.length) ,这样,在上面的for循环中已经判断了一次 i <this.elementData.length,在这个检查中只需要将上面计算的结果进行取反。也就是公共表达式消除。

到这里,我们只剩下了,最开始的elementData == null的检查,循环中的elementData == null的检查,以及循环中的action == null的检查。

这三个null检查是无法消除的,因此对于无法消除的null检查,我们就需要将检查的开销将到最低。

因此,在编译时,检查elementData == null将会被直接消除,然后jvm会注册一个sig_fault handler,当出现了elementData为null的时候,jvm不会直接崩溃,而是会进入trap,通过操作系统调用段错误处理器,然后再抛出一个NPE异常。

这样的优化比起直接新建一个分支来判断会快很多。前提是,这个elementData并不是null。类似于jvm断言这里肯定不会产生NPE异常,但是一旦产生null,就会陷入trap,然后去优化,整体就会变慢。

循环剥离

在上面的null判断中,除了使用trap,还有一种方法,也就是将第一次循环剥离出来。

这时,代码变成了下面这样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
ArrayStream.forEach(Consumer<? super E> action) {
if (this.elementData == null) {
throw new NPE();
}
if (0 < this.elementData.length()) {
if (this.elementData == null) {
throw new NPE();
}
if (action == null) throw new NPE();
action.accept(this.elementData[0]);
}

for (int i = 1; i < this.elementData.length(); i++) {
if (this.elementData == null) throw new NPE();
action.accept(this.elementData[i]);
}
}

这样我们让action的null检查从n次变成了一次。

方法内联

方法内联很简单,本质就是将被调用的方法的代码复制到调用者的方法体里,这样就消除了一次调用。

比如遍历elementData,将每一个元素传入一个consume方法,该方法将传入的每一个数字加到sum上。

1
2
3
static void consume(int x) {
sum += x;
}

采用for循环的话有三种方式。

第一种就是简单的获取this.elementData并遍历

1
2
3
for (int i = 0; i < this.elementData.length(); i++) {
consume(this.elementData[i]);
}

第二种方法和第一种有一点不同,我们将this.elementData用一个局部变量保存起来。

1
2
3
4
E[] elementData = this.elementData;
for (int i = 0; i < elementData.length(); i++) {
consume(elementData[i]);
}

第三种方式是使用Java的增强for循环

1
2
3
for (int i : this.elementData) {
consume(elementData[x]);
}

在consume方法内联的情况下,三个函数的执行效率都差不多。

但是,在不内联的情况下,第一种方式就比第二和第三种慢了许多。

原因在于,第一种在每一次循环都需要重新加载elementData,因为elementData可能被修改。

第二种方法使用了局部变量,因为局部变量是无法被修改的,所以不需要被重新加载。

第三种方法其实和第二种本质上没有什么差别,增强for循环也会创建一个局部变量。

但是对于虚函数,由于会有不同的实现类,那么又该如何内联呢。

去虚拟化

在Java中,除了private函数以及final修饰的函数,几乎所有的对象函数都是虚函数。虚函数是实现多态的基础。

每次调用虚函数,我们都要去找到对应的真实类型的函数,并进行调用。那么要怎么样才能去除虚函数的开销呢。

首先是静态分析。

假设有这样一个函数apply,接受一个Function参数,和一个int类型的参数,然后在函数内调用func.apply(x)

1
2
3
static double apply(Function func, int x) {
func.apply(x);
}

然后在main函数中调用apply

1
2
3
4
5
6
7
8
public class Monomorphic {
public static void main(String[] args) {
Function func = new Square();
for (int i = 0; i < 20_000; ++i) {
apply(func, i);
}
}
}

然后我们查看编译的日志

iHpJ23.jpeg

就会发现,编译器将先将apply函数内联进了main函数,然后直接将Square的apply内联进了Monomorphic的apply函数。

这里JIT使用了一个封闭世界的假设,它发现现在整个jvm中只有一个Function的实例,也就是Square,其他的实现了Function接口的类可能存在,但是现在还没有被加载。

如果我们在这时加载了一个其他的类,比如Sqrt。这时实现了Function接口的类就不只有Square了,还存在Sqrt。这时,上面的假设就不成立了。

因此这时,前面的编译将会失效,因为它直接内联了Square的apply函数。jvm将会立刻去优化。

在这时,C2编译器将会被锁住,所有的线程将会返回到解释模式,性能将会大幅下降。

这时,虚拟机将会开始收集调用的信息,被称为类型采样分析,这些信息由C1或者解释器收集。

虚拟机将会收集每个Function的实现类被调用了几次,从哪里调用。

假如在20000次Function.apply调用中,有19000次都是Square.apply,只有1000次是Sqrt.apply。这时JIT就可以基于这些信息开始编译。

这次编译并不会像之前那样直接内联Square和Sqrt的apply函数。而是会加上类型检查。

这时,代码将会在变成类似下方这样。

1
2
3
4
5
6
7
if (func.getClass().equals(Square.class)) {
// 内联Square的apply函数
} else if (func.getClass().equals(Sqrt.class)) {
// 内联Sqrt的apply函数
} else {
uncommon_trap()
}

对于为什么不在else后面直接放虚函数调用,而还是加上一个uncommon trap。这时由于如果加上虚函数调用,那么在那个虚函数中又有可能会出现修改变量的行为,导致之前编译的代码出现问题。虚函数调用可以使用反射,反射可以修改几乎任何东西。

总结

这里只写了JIT优化的很少一部分,还有很多其他的东西都没有写,比如GC是如何实现Stop The World的,GC如何与编译后的函数交互等等。

从这里可以看出,Java很喜欢小方法和局部变量,很喜欢不变的东西,常量。Java也并不喜欢原生方法,因为Java虚拟机无法知晓native函数中具体做了什么,也无法优化,并且有可能导致崩溃。除非是intrinsic函数,也就是在jvm中有特殊实现的函数。

因此,写Java代码应该保持函数尽可能的小,当函数字节码大小超过8000字节之后,JIT就不会对该函数进行优化。JIT将会将各个小函数内联起来并进行优化。不应该一直抛出异常,因为异常可能会导致编译优化失效。变量应该尽量使用局部变量,因为局部变量的不变性使得他们更容易被优化。

通过USN日志进行文件监控

在以前的文章中有一篇是自己实现一个Everything,其中讲了通过readDirectChanges函数进行文件监控并同步的方法。但是这样的方法在监控整个磁盘时好像会漏掉一些文件。

下面介绍另一种方法,通过读取USN日志来进行文件的监控。

代码已经开源到GitHub,之前的ReadDirectoryChanges API的版本也有保存。

File-Engine/C++/fileMonitor at master · XUANXUQAQ/File-Engine (github.com)

File-Engine/C++/fileMonitorReadDirChanges at master · XUANXUQAQ/File-Engine (github.com)

代码以及资料参考自

windows - USN NFTS change notification event interrupt - Stack Overflow

c++ - How can I detect only deleted, changed, and created files on a volume? - Stack Overflow

原理

Obtaining Directory Change Notifications - Win32 apps | Microsoft Learn

在微软官网这篇文章中,详细写了如何获取文件夹的变化通知。

Change Journals - Win32 apps | Microsoft Learn

Keeping an Eye on Your NTFS Drives: the Windows 2000 Change Journal Explained | Microsoft Learn

这里详细介绍了NTFS的usn日志是什么,以及usn日志的数据结构等。

简单来说,每当一个文件进行变动,都会写入usn日志。我们可以通过监控是否有新的usn日志记录写入来判断是否有文件更改,并进行监控。

实现

定义监控类

首先定义一个NTFSChangesWatcher类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#pragma once
#include <memory>
#include <string>
#include <Windows.h>
class NTFSChangesWatcher
{
public:
NTFSChangesWatcher(char drive_letter);
~NTFSChangesWatcher() = default;

// Method which runs an infinite loop and waits for new update sequence number in a journal.
// The thread is blocked till the new USN record created in the journal.
void WatchChanges(const bool* flag, void(*)(const std::u16string&), void(*)(const std::u16string&));

private:
HANDLE OpenVolume(char drive_letter);

bool CreateJournal(HANDLE volume);

bool LoadJournal(HANDLE volume, USN_JOURNAL_DATA* journal_data);

bool WaitForNextUsn(PREAD_USN_JOURNAL_DATA read_journal_data) const;

std::unique_ptr<READ_USN_JOURNAL_DATA> GetWaitForNextUsnQuery(USN start_usn);

bool ReadJournalRecords(PREAD_USN_JOURNAL_DATA journal_query, LPVOID buffer,
DWORD& byte_count) const;

USN ReadChangesAndNotify(USN low_usn, char* buffer, void(*)(const std::u16string&), void(*)(const std::u16string&));

std::unique_ptr<READ_USN_JOURNAL_DATA> GetReadJournalQuery(USN low_usn);

void showRecord(std::u16string& full_path, USN_RECORD* record);

char drive_letter_;

HANDLE volume_;

std::unique_ptr<USN_JOURNAL_DATA> journal_;

DWORDLONG journal_id_;

USN last_usn_;

USN max_usn_;

// Flags, which indicate which types of changes you want to listen.
static const int FILE_CHANGE_BITMASK;

static const int kBufferSize;
};

对外的接口函数为WatchChanges

1
void WatchChanges(const bool* flag, void(*)(const std::u16string&), void(*)(const std::u16string&));

函数有三个参数,第一个为停止监控文件标志,当设置为false将会退出循环。第二个参数为当新增文件时的回调函数指针,第三个参数为删除文件时的回调函数指针。

初始化USN日志

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
const int NTFSChangesWatcher::kBufferSize = 1024 * 1024 / 2;

const int NTFSChangesWatcher::FILE_CHANGE_BITMASK = USN_REASON_RENAME_NEW_NAME | USN_REASON_RENAME_OLD_NAME;

NTFSChangesWatcher::NTFSChangesWatcher(char drive_letter) :
drive_letter_(drive_letter)
{
volume_ = OpenVolume(drive_letter_);

journal_ = std::make_unique<USN_JOURNAL_DATA>();

if (const bool res = LoadJournal(volume_, journal_.get()); !res) {
fprintf(stderr, "Failed to load journal");
return;
}
max_usn_ = journal_->MaxUsn;
journal_id_ = journal_->UsnJournalID;
last_usn_ = journal_->NextUsn;
}

首先通过OpenVolume打开磁盘,并返回一个HANDLE,然后分配存储日志的内存空间,接着通过LoadJournal读取usn日志。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
HANDLE NTFSChangesWatcher::OpenVolume(const char drive_letter)
{

wchar_t pattern[10] = L"\\\\?\\a:";

pattern[4] = static_cast<wchar_t>(drive_letter);

const HANDLE volume = CreateFile(
pattern, // lpFileName
// also could be | FILE_READ_DATA | FILE_READ_ATTRIBUTES | SYNCHRONIZE
GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE, // dwDesiredAccess
FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, // share mode
nullptr, // default security attributes
OPEN_EXISTING, // disposition
// It is always set, no matter whether you explicitly specify it or not. This means, that access
// must be aligned with sector size so we can only read a number of bytes that is a multiple of the sector size.
FILE_FLAG_NO_BUFFERING, // file attributes
nullptr // do not copy file attributes
);

if (volume == INVALID_HANDLE_VALUE) {
// An error occurred!
fprintf(stderr, "Failed to open volume");
return nullptr;
}

return volume;
}

获取HANDLE后,通过LoadJournal获取USN日志,第一次读取失败将会尝试创建后再次尝试读取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
bool NTFSChangesWatcher::LoadJournal(HANDLE volume, USN_JOURNAL_DATA* journal_data)
{

DWORD byte_count;

// Try to open journal.
if (!DeviceIoControl(volume,
FSCTL_QUERY_USN_JOURNAL,
nullptr,
0,
journal_data,
sizeof(*journal_data),
&byte_count,
nullptr))
{
// If failed (for example, in case journaling is disabled), create journal and retry.

if (CreateJournal(volume)) {
return LoadJournal(volume, journal_data);
}

return false;
}
return true;
}


bool NTFSChangesWatcher::CreateJournal(HANDLE volume)
{

DWORD byte_count;
CREATE_USN_JOURNAL_DATA create_journal_data{};

const bool ok = DeviceIoControl(volume, // handle to volume
FSCTL_CREATE_USN_JOURNAL, // dwIoControlCode
&create_journal_data, // input buffer
sizeof(create_journal_data), // size of input buffer
nullptr, // lpOutBuffer
0, // nOutBufferSize
&byte_count, // number of bytes returned
nullptr) != 0; // OVERLAPPED structure

if (!ok) {
// An error occurred!
}

return ok;
}

开始监控

初始化完成之后就可以调用WatchChanges函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void NTFSChangesWatcher::WatchChanges(const bool* flag,
void(*file_added_callback_func)(const std::u16string&),
void(*file_removed_callback_func)(const std::u16string&))
{
const auto u_buffer = std::make_unique<char[]>(kBufferSize);

const auto read_journal_query = GetWaitForNextUsnQuery(last_usn_);

while (*flag)
{
// This function does not return until new USN record created.
WaitForNextUsn(read_journal_query.get());
last_usn_ = ReadChangesAndNotify(read_journal_query->StartUsn,
u_buffer.get(),
file_added_callback_func,
file_removed_callback_func);
read_journal_query->StartUsn = last_usn_;
}
delete flag;
}

核心的方法就两个,一个WaitForNextUsn,一个ReadChangesAndNotify

首先来看WaitForNextUsn

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
bool NTFSChangesWatcher::WaitForNextUsn(PREAD_USN_JOURNAL_DATA read_journal_data) const
{

DWORD bytes_read;

// This function does not return until new USN record created.
const bool ok = DeviceIoControl(volume_,
FSCTL_READ_USN_JOURNAL,
read_journal_data,
sizeof(*read_journal_data),
&read_journal_data->StartUsn,
sizeof(read_journal_data->StartUsn),
&bytes_read,
nullptr) != 0;
return ok;
}

通过DeviceIoControl函数,发送FSCTL_READ_USN_JOURNAL事件,由于我们之前初始化的时候设置了从最后一个usn记录开始读取,这时该方法将会阻塞直到用户进行操作,NTFS写入一个新的USN日志。

这里的最后一个参数lpOverlapped必须为NULL,因为我们要监控文件的变化,需要阻塞函数,如果是异步调用反而会有各种各样的不方便。

关于DeviceIoControl函数网上已经有很多解释,这里就放个msdn吧。

DeviceIoControl function (ioapiset.h) - Win32 apps | Microsoft Learn

以及FSCTL_READ_USN_JOURNAL

FSCTL_READ_USN_JOURNAL - Win32 apps | Microsoft Learn

当该方法返回后,代表磁盘中出现了一个新的usn记录,这时就会执行到下一个函数

ReadChangesAndNotify

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
USN NTFSChangesWatcher::ReadChangesAndNotify(USN low_usn,
char* buffer,
void(*file_added_callback_func)(const std::u16string&),
void(*file_removed_callback_func)(const std::u16string&))
{

DWORD byte_count;

const auto journal_query = GetReadJournalQuery(low_usn);
memset(buffer, 0, kBufferSize);
if (!ReadJournalRecords(journal_query.get(), buffer, byte_count))
{
// An error occurred.
return low_usn;
}

auto record = reinterpret_cast<USN_RECORD*>(reinterpret_cast<USN*>(buffer) + 1);
const auto record_end = reinterpret_cast<USN_RECORD*>(reinterpret_cast<BYTE*>(buffer) + byte_count);

std::u16string full_path;
for (; record < record_end;
record = reinterpret_cast<USN_RECORD*>(reinterpret_cast<BYTE*>(record) + record->RecordLength))
{
const auto reason = record->Reason;
full_path.clear();
// It is really strange, but some system files creating and deleting at the same time.
if ((reason & USN_REASON_FILE_CREATE) && (reason & USN_REASON_FILE_DELETE))
{
continue;
}
if ((reason & USN_REASON_FILE_CREATE) && (reason & USN_REASON_CLOSE))
{
showRecord(full_path, record);
file_added_callback_func(full_path);
}
else if ((reason & USN_REASON_FILE_DELETE) && (reason & USN_REASON_CLOSE))
{
showRecord(full_path, record);
file_removed_callback_func(full_path);
}
else if (reason & FILE_CHANGE_BITMASK)
{
if (reason & USN_REASON_RENAME_OLD_NAME)
{
showRecord(full_path, record);
file_removed_callback_func(full_path);
}
else if (reason & USN_REASON_RENAME_NEW_NAME)
{
showRecord(full_path, record);
file_added_callback_func(full_path);
}
}
}
return *reinterpret_cast<USN*>(buffer);
}

这里ReadJournalRecords将会调用DeviceIoControl函数发送FSCTL_READ_USN_JOURNAL读出新的USN日志记录。

读取完成后,通过获取USN_RECORD中的reason字段,得到文件是创建,还是被删除。其实还有很多其他的USN_REASON,不过这里由于只需要检测文件变化,因此只监听了

  • USN_REASON_FILE_CREATE

  • USN_REASON_FILE_DELETE

  • USN_REASON_RENAME_OLD_NAME

  • USN_REASON_RENAME_NEW_NAME

所有的原因可以参考这里

USN_RECORD_V2 - Win32 apps | Microsoft Learn

获取文件完整路径

由于USN日志中记录的只有文件名和文件参照号,因此我们需要通过文件参照号和父文件参照号不断向上查询,拼接出完整的路径。

也就是上面的showRecord函数,该函数有两个参数,full_path,USN_RECORD指针类型的record,也就是需要拼接出完整路径的文件记录。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
void NTFSChangesWatcher::showRecord(std::u16string& full_path, USN_RECORD* record)
{
static std::wstring sep_wstr(L"\\");
static std::u16string sep(sep_wstr.begin(), sep_wstr.end());

const indexer_common::FileInfo file_info(*record, drive_letter_);
if (full_path.empty())
{
full_path += file_info.GetName();
}
else
{
full_path = file_info.GetName() + sep + full_path;
}
DWORD byte_count = 1;
auto buffer = std::make_unique<char[]>(kBufferSize);

MFT_ENUM_DATA_V0 med;
med.StartFileReferenceNumber = record->ParentFileReferenceNumber;
med.LowUsn = 0;
med.HighUsn = max_usn_;

if (!DeviceIoControl(volume_,
FSCTL_ENUM_USN_DATA,
&med,
sizeof(med),
buffer.get(),
kBufferSize,
&byte_count,
nullptr))
{
return;
}

auto* parent_record = reinterpret_cast<USN_RECORD*>(reinterpret_cast<USN*>(buffer.get()) + 1);

if (parent_record->FileReferenceNumber != record->ParentFileReferenceNumber)
{
static std::wstring colon_wstr(L":");
static std::u16string colon(colon_wstr.begin(), colon_wstr.end());
std::string drive;
drive += drive_letter_;
auto&& w_drive = string2wstring(drive);
const std::u16string drive_u16(w_drive.begin(), w_drive.end());
full_path = drive_u16 + colon + sep + full_path;
return;
}
showRecord(full_path, parent_record);
}

首先获得文件名和父文件参照号,然后定义一个MFT_ENUM_DATA,由于MFT_ENUM_DATA_V1会报错Error 87,也就是ERROR_INVALID_PARAMETER,所以这里改成了MFT_ENUM_DATA_V0

System Error Codes (0-499) (WinError.h) - Win32 apps | Microsoft Learn

将开始查询地址设置为record->ParentFileReferenceNumber,并将上界设置为最开始初始化的max_usn_。

然后调用DeviceIoControl,发送FSCTL_ENUM_USN_DATA事件,就可以读取出record的父文件夹的USN记录。

这时,将查询出的父文件夹记录再作为record进行递归查询。

不断向上查询,将文件名拼接到full_path中,最后找到顶层退出递归即可。

获得文件完整路径后,即可调用两个回调函数进行处理了。

前面说到了使用GPU加速文件搜索,但是会占用较多的显存,如果在打游戏,或者跑深度学习的情况下,并不会进行搜索,那么显存就被白白浪费掉了。因此如何实现显存的使用率监控,并实现检测到显存占用过多自动释放搜索缓存就成为了一个问题。

CUDA方面

在NVIDIA的文档中,有一个函数cudaMemGetInfo

1
__host__​cudaError_t cudaMemGetInfo ( size_t* free, size_t* total )

该方法可以直接获取GPU的空闲内存以及总内存。

但是经过实测,返回结果并不准确。因此并没有使用该方法。

OpenCL方面

在OpenCL中没有可以获取GPU占用的函数,但是由于OpenCL可以通过cl_device::getInfo获取显卡的各项参数,因此可以通过OpenCL提供的拓展参数来获取。

比如AMD显卡的OpenCL扩展中就有一项

1
2
3
4
#define CL_DEVICE_GLOBAL_FREE_MEMORY_AMD                0x4039

// 通过该方法即可获取AMD显卡的显存占用
cl_device::getInfo<CL_DEVICE_GLOBAL_FREE_MEMORY_AMD>();

AMD所有的OpenCL扩展在这里

https://registry.khronos.org/OpenCL/extensions/amd/cl_amd_device_attribute_query.txt

不过这样也有一些问题,比如适配困难。现在Intel也推出了自家的独立显卡Arc系列,不同显卡的适配就成为了一个问题。因此该方法也没有采用。

Windows任务管理器

其实在Windows任务管理器中就可以看到显存的占用,因此如何获取这个数值就成为了问题的关键。

l27CC.jpg

在Windows的powershell中有这样的API

Get-Counter (Microsoft.PowerShell.Diagnostics) - PowerShell | Microsoft Learn

通过Get-Counter可以实现获取系统中的各项参数。比如我们这里需要获取GPU相关的信息

l2enL.jpg

在Stack Overflow也找到了相同的解决方案

c++ - How to get the “Dedicated GPU memory” number for every running process in Windows (The same numbers that are shown in the Windows Task Manager) - Stack Overflow

由于powershell调用的效率太低,因此我们需要直接使用对应的Windows API,而不是重复调用powershell脚本。

也就是Windows API中的PDH函数

PDH函数

Using the PDH Functions to Consume Counter Data - Win32 apps | Microsoft Learn

通过PdhOpenQuery函数打开一个查询,然后使用PdhAddCounter设置要查询的信息,最后使用PdhCollectQueryData获取信息,再使用PdhGetRawCounterArray进行转换,即可获取我们想要的显存占用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
std::unordered_map<std::wstring, LONGLONG> query_pdh_val(PDH_STATUS& ret)
{
PDH_HQUERY query;
std::unordered_map<std::wstring, LONGLONG> memory_usage_map;
ret = PdhOpenQuery(nullptr, NULL, &query);
if (ret != ERROR_SUCCESS)
{
return memory_usage_map;
}
PDH_HCOUNTER counter;
ret = PdhAddCounter(query, L"\\GPU Adapter Memory(*)\\Dedicated Usage", NULL, &counter);
if (ret != ERROR_SUCCESS)
{
return memory_usage_map;
}
ret = PdhCollectQueryData(query);
if (ret != ERROR_SUCCESS)
{
return memory_usage_map;
}
DWORD bufferSize = 0;
DWORD itemCount = 0;
PdhGetRawCounterArray(counter, &bufferSize, &itemCount, nullptr);
auto&& lpItemBuffer = reinterpret_cast<PPDH_RAW_COUNTER_ITEM_W>(new char[bufferSize]);
ret = PdhGetRawCounterArray(counter, &bufferSize, &itemCount, lpItemBuffer);
if (ret != ERROR_SUCCESS)
{
delete[] lpItemBuffer;
return memory_usage_map;
}
for (DWORD i = 0; i < itemCount; ++i)
{
auto& [szName, RawValue] = lpItemBuffer[i];
memory_usage_map.insert(std::make_pair(szName, RawValue.FirstValue));
}
delete[] lpItemBuffer;
ret = PdhCloseQuery(query);
return memory_usage_map;
}

这里的query_pdh_val()方法返回的是GPU的luid和显存占用的map。

和powershell返回的结果是相同的

l2t1q.jpg

但是这里又出现了一个大问题,这些luid开头的东西是什么,和显卡是怎么进行对应的。

我们平时看到的显卡名都是像NVIDIA GeForce RTX 2060,AMD Radeon RX 5700XT这样

l2vpx.jpg

powershell中返回的gpu adapter memory(luid_0x00000000_0xXXXXXXXX)到底是如何和显卡一一对应的仍然不清楚。

GPU名称和luid的对应查询方案

在网上找了非常多的资料,都没有找到。最后是在翻看Windows DirectX API中偶然发现了在DXGI_ADAPTER_DESC结构体中有一个字段,名字就叫AdapterLuid

1
2
3
4
5
6
7
8
9
10
11
typedef struct DXGI_ADAPTER_DESC {
WCHAR Description[128];
UINT VendorId;
UINT DeviceId;
UINT SubSysId;
UINT Revision;
SIZE_T DedicatedVideoMemory;
SIZE_T DedicatedSystemMemory;
SIZE_T SharedSystemMemory;
LUID AdapterLuid;
} DXGI_ADAPTER_DESC;

因此,只需要通过DirectX API将显卡和luid对应起来,再和前面的PDH函数查询出的信息进行对应,即可实现显卡和显存占用信息的对应。

具体的方法为如下:

创建DXGIFactory,这里由于是jni调用的,所以我抛出了一个Java的异常。

1
2
3
4
5
if (CreateDXGIFactory(__uuidof(IDXGIFactory), reinterpret_cast<void**>(&p_dxgi_factory)) != S_OK)
{
env->ThrowNew(env->FindClass("java/lang/Exception"), "Create dxgi factory failed.");
return;
}

然后使用DXGIFactory遍历所有GPU,即可获取GPU名称和luid的对应关系

1
2
3
4
5
6
7
8
for (UINT i = 0;
p_dxgi_factory->EnumAdapters(i, &p_adapter) != DXGI_ERROR_NOT_FOUND;
++i)
{
DXGI_ADAPTER_DESC adapter_desc;
p_adapter->GetDesc(&adapter_desc);
gpu_name_adapter_map.insert(std::make_pair(adapter_desc.Description, adapter_desc));
}

将对应关系保存至gpu_name_adapter_map中。

Luid有两个成员,一个是LowPart,一个是HighPart

1
2
3
4
typedef struct _LUID {
DWORD LowPart;
LONG HighPart;
} LUID, *PLUID;

在Windows系统中是高位在前,低位在后,因此通过下面的方法即可获取luid,n2hexstr()方法是数字转16进制字符串。

1
2
3
4
5
6
7
8
9
10
11
template <typename I>
std::string n2hexstr(I w, size_t hex_len = sizeof(I) << 1)
{
static const char* digits = "0123456789ABCDEF";
std::string rc(hex_len, '0');
for (size_t i = 0, j = (hex_len - 1) * 4; i < hex_len; ++i, j -= 4)
rc[i] = digits[w >> j & 0x0f];
return rc;
}

auto&& luid_str = "0x" + n2hexstr(adapter_luid.HighPart) + "_" + "0x" + n2hexstr(adapter_luid.LowPart);

在我这台电脑上,例如想监控NVIDIA GeForce RTX 2060的显卡的显存占用。先通过gpu_name_adapter_map找到对应的luid,为0x00000000_0x00010fcf,然后调用query_pdh_val()函数获取所有显卡的显存占用,最后通过luid找出对应的显存占用即可。

至此,对显卡显存占用的查询监控方案结束。

在File-Engine中为了加速文件搜索的速度,加入了GPU加速的功能。通过显卡来进行并行运算,大幅提高搜索的速度。
由于显卡可以进行大量的并行运算,而每一个文件都是独立的,进行字符串匹配并不会相互干扰,因此GPU非常适合用来加速搜索。

具体的实现方法如下。

定义GPU加速接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package file.engine.dllInterface.gpu;

import java.util.function.BiConsumer;
import java.util.function.Supplier;

public interface IGPUAccelerator {

/**
* 重置搜索的标志
*/
void resetAllResultStatus();

/**
* 进行搜索,每当有一个结果匹配后调用resultCollector
* @param searchCase:有D F Full Case四种,分别代表只搜索文件夹(Directory),只搜索文件(File),全字匹配(Full),大小写敏感(Case Sensitive)
* @param isIgnoreCase 是否忽略大小写,当searchCase中存在case时该字段为true
* @param searchText 原始搜索关键字
* @param keywords 搜索关键字,将searchText用";"切割得到
* @param keywordsLowerCase 搜索关键字,全小写字母,内容与keywords相同,但字母为全小写
* @param isKeywordPath 保存keywords中的关键字是路径判断还是文件名判断
* @param maxResultNumber 最大匹配结果数量限制
* @param resultCollector 有一个结果匹配后的回调方法
*/
void match(String[] searchCase,
boolean isIgnoreCase,
String searchText,
String[] keywords,
String[] keywordsLowerCase,
boolean[] isKeywordPath,
int maxResultNumber,
BiConsumer<String, String> resultCollector);

/**
* 判断GPU加速是否可用
* @return true如果可以进行GPU加速
*/
boolean isGPUAvailableOnSystem();

/**
* 判断某一个缓存是否搜索完成
* @param key 缓存key
* @return true如果全部完成
*/
boolean isMatchDone(String key);

/**
* 获取某一个缓存搜索完成后匹配的结果数量
* @param key 缓存key
* @return 匹配的结果数量
*/
int matchedNumber(String key);

/**
* 停止搜索
*/
void stopCollectResults();

/**
* 判断缓存是否存在
* @return true如果至少保存了一个缓存
*/
boolean hasCache();

/**
* 判断某一个缓存是否存在
* @param key 缓存key
* @return true如果缓存名为key的缓存存在
*/
boolean isCacheExist(String key);

/**
* 添加缓存到GPU显存
* @param key 缓存key
* @param recordSupplier 字符串supplier,由GPU加速dll通过jni进行调用,防止字符串过多导致OOM
*/
void initCache(String key, Supplier<String> recordSupplier);

/**
* 向某个缓存添加数据
* @param key 缓存key
* @param records 待添加的数据
*/
void addRecordsToCache(String key, Object[] records);

/**
* 删除某一个缓存中的数据
* @param key 缓存key
* @param records 待删除的数据
*/
void removeRecordsFromCache(String key, Object[] records);

/**
* 删除某一个缓存
* @param key 缓存key
*/
void clearCache(String key);

/**
* 删除所有缓存
*/
void clearAllCache();

/**
* 缓存是否有效
* @param key 缓存key
* @return true如果缓存仍然有效,当addRecordsToCache()失败,表示当前缓存的空位已经不足,出现了数据丢失,缓存被标记为无效。
* @see #addRecordsToCache(String, Object[])
*/
boolean isCacheValid(String key);

/**
* 获取系统显存占用,用于在用户使用过多显存时释放缓存。
* @return 内存占用百分比,数值从0-100
*/
int getGPUMemUsage();

/**
* 初始化
*/
void initialize();

/**
* 释放所有资源
*/
void release();

/**
* 获取GPU设备名
* @return GPU设备名
*/
String[] getDevices();

/**
* 设置使用的GPU设备,deviceNum为getDevices()返回的数组的下标
* @param deviceNum GPU设备id
* @return true如果成功使用设备并初始化
*/
boolean setDevice(int deviceNum);
}

该接口定义了GPU加速的基本方法,以后如果有新的GPU计算框架可以用于加速,只需要重新实现这个接口即可。
目前,该接口有两个实现,分别使用CUDA和OpenCL来进行加速。然后采用一个统一的包装类GPUAccelerator进行管理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package file.engine.dllInterface.gpu;

import file.engine.configs.AllConfigs;
import file.engine.event.handler.EventManagement;
import file.engine.event.handler.impl.stop.RestartEvent;
import file.engine.utils.RegexUtil;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

public enum GPUAccelerator {
INSTANCE;
private static IGPUAccelerator gpuAccelerator;
private static final CudaAccelerator cudaAccelerator = CudaAccelerator.INSTANCE;
private static final OpenclAccelerator openclAccelerator = OpenclAccelerator.INSTANCE;

/**
* 之所以采用双重检验锁机制,是由于要实现懒加载,并且不能在加载类的时候进行加载
* 由于在事务管理器扫描@EventRegister@EventListener的阶段将会尝试加载所有类,此时配置中心还不可用
* 因此采用getInstance()方法来实现懒加载,在BootSystem事件发出后再进行初始化。
* @param isEnableGpuAccelerate
*/
record IsEnabledWrapper(boolean isEnableGpuAccelerate) {
private static volatile IsEnabledWrapper instance;

public static IsEnabledWrapper getInstance() {
if (instance == null) {
synchronized (IsEnabledWrapper.class) {
if (instance == null) {
instance = new IsEnabledWrapper(AllConfigs.getInstance().getConfigEntity().isEnableGpuAccelerate());
}
}
}
return instance;
}
}

enum Category {
CUDA("cuda"), OPENCL("opencl");
final String category;

Category(String category) {
this.category = category;
}

@Override
public String toString() {
return this.category;
}

static Category categoryFromString(String c) {
return switch (c) {
case "cuda" -> CUDA;
case "opencl" -> OPENCL;
default -> null;
};
}
}

...... GPU加速方法调用

/**
* key: 设备名
* value: [设备种类(cuda, opencl)];[设备id]
*
* @return map
*/
public Map<String, String> getDevices() {
LinkedHashMap<String, String> deviceMap = new LinkedHashMap<>();
getDeviceToMap(cudaAccelerator, deviceMap, Category.CUDA);
getDeviceToMap(openclAccelerator, deviceMap, Category.OPENCL);
return deviceMap;
}

private void getDeviceToMap(IGPUAccelerator igpuAccelerator, HashMap<String, String> deviceMap, Category category) {
if (igpuAccelerator.isGPUAvailableOnSystem()) {
var devices = igpuAccelerator.getDevices();
if (devices == null || devices.length == 0) {
return;
}
for (int i = 0; i < devices.length; ++i) {
var deviceName = devices[i];
if (deviceName.isBlank()) {
continue;
}
try {
if (!deviceMap.containsKey(deviceName)) {
deviceMap.put(deviceName, category + ";" + i);
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
}

public boolean setDevice(String deviceCategoryAndId) {
if (gpuAccelerator != null) {
// 切换GPU设备重启生效,运行中不允许切换
return true;
}
if (!IsEnabledWrapper.getInstance().isEnableGpuAccelerate()) {
return false;
}
if (deviceCategoryAndId.isEmpty()) {
if (cudaAccelerator.isGPUAvailableOnSystem()) {
cudaAccelerator.initialize();
if (cudaAccelerator.setDevice(0)) {
gpuAccelerator = cudaAccelerator;
return true;
}
}
if (openclAccelerator.isGPUAvailableOnSystem()) {
openclAccelerator.initialize();
if (openclAccelerator.setDevice(0)) {
gpuAccelerator = openclAccelerator;
return true;
}
}
return false;
}
String[] info = RegexUtil.semicolon.split(deviceCategoryAndId);
String deviceCategory = info[0];
int id = Integer.parseInt(info[1]);
var category = Category.categoryFromString(deviceCategory);
if (category != null) {
switch (category) {
case CUDA:
if (cudaAccelerator.isGPUAvailableOnSystem()) {
cudaAccelerator.initialize();
if (cudaAccelerator.setDevice(id)) {
gpuAccelerator = cudaAccelerator;
return true;
}
}
case OPENCL:
if (openclAccelerator.isGPUAvailableOnSystem()) {
openclAccelerator.initialize();
if (openclAccelerator.setDevice(id)) {
gpuAccelerator = openclAccelerator;
return true;
}
}
}
}
return false;
}

@SuppressWarnings("unused")
public static void sendRestartOnError0() {
System.err.println("GPU缓存出错,自动重启");
EventManagement.getInstance().putEvent(new RestartEvent());
}
}

下面介绍CUDA加速。

CUDA加速首先会加载cudaAccelerator.dll,如果加载成功,将会设置isCudaLoaded为true。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
enum CudaAccelerator implements IGPUAccelerator {
INSTANCE;

private static boolean isCudaLoaded;

static {
try {
System.load(Path.of("user/cudaAccelerator.dll").toAbsolutePath().toString());
isCudaLoaded = true;
} catch (UnsatisfiedLinkError | Exception e) {
e.printStackTrace();
isCudaLoaded = false;
}
}
}

初始化过程

加载完成后,将会调用setDevice(int deviceId)来选择GPU并进行初始化。
首先会调用initialize()方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*
* Class: file_engine_dllInterface_gpu_CudaAccelerator
* Method: initialize
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_file_engine_dllInterface_gpu_CudaAccelerator_initialize
(JNIEnv* env, jobject)
{
init_stop_signal();
init_cuda_search_memory();
init_str_convert();
//默认使用第一个设备,current_using_device=0
set_using_device(current_using_device);
if (env->GetJavaVM(&jvm) != JNI_OK)
{
env->ThrowNew(env->FindClass("java/lang/Exception"), "get JavaVM ptr failed.");
return;
}
set_jvm_ptr_in_kernel(jvm);
if (CreateDXGIFactory(__uuidof(IDXGIFactory), reinterpret_cast<void**>(&p_dxgi_factory)) != S_OK)
{
env->ThrowNew(env->FindClass("java/lang/Exception"), "create dxgi factory failed.");
}
IDXGIAdapter* p_adapter = nullptr;
for (UINT i = 0;
p_dxgi_factory->EnumAdapters(i, &p_adapter) != DXGI_ERROR_NOT_FOUND;
++i)
{
DXGI_ADAPTER_DESC adapter_desc;
p_adapter->GetDesc(&adapter_desc);
gpu_name_adapter_map.insert(std::make_pair(adapter_desc.Description, adapter_desc));
}
}

首先初始化停止搜索信号,当结果数量达到maxResults限制,停止信号将会被设置为true,在resetAllResultStatus()方法将会把停止信号重置为false。

随后将会初始化cuda相关内存,如存储搜索关键字,存储搜索过滤条件等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void init_cuda_search_memory()
{
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&dev_search_case), sizeof(int)), true, nullptr);
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&dev_search_text), MAX_PATH_LENGTH * sizeof(char)), true, nullptr);
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&dev_keywords_length), sizeof(size_t)), true, nullptr);
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&dev_is_keyword_path), sizeof(bool) * MAX_KEYWORDS_NUMBER), true,
nullptr);
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&dev_is_ignore_case), sizeof(bool)), true, nullptr);
gpuErrchk(
cudaMalloc(reinterpret_cast<void**>(&dev_keywords), static_cast<size_t>(MAX_PATH_LENGTH * MAX_KEYWORDS_NUMBER)),
true, nullptr);
gpuErrchk(
cudaMalloc(reinterpret_cast<void**>(&dev_keywords_lower_case), static_cast<size_t>(MAX_PATH_LENGTH *
MAX_KEYWORDS_NUMBER)), true, nullptr);
}

最后初始化字符串转换器,字符串转换器可以在CPU中实现字符串从UTF-8编码转换到GB2312编码,然后实现从中文字符串转换到拼音,实现拼音搜索。

1
2
3
4
5
6
7
void init_str_convert()
{
init_gbk2_utf16_2();
init_gbk2_utf16_3();
// init_gbk2_utf16();
init_utf162_gbk();
}

初始化完成后将会获取jvm指针,用于在match方法开启cuda核函数后,使用多个线程收集处理结果,并将结果返回到Java虚拟机中。通过std::thread开启线程后用jvm指针使线程绑定到Java虚拟机从而调用Java方法。

最后初始化gpu_name_adapter,这是Windows中的directX API,在该项目中主要用于实现显存的监控,因此在此暂时不讨论。

到此初始化完成。

添加缓存步骤

接下来谈一下initCache方法初始化缓存。

首先看一下缓存的结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
* \brief 存储数据结构
* remain_blank_num:当前dev_cache_str有多少空闲空间
* record_num:当前有多少个record
* record_hash:每个record的hash,用于判断重复
*/
using cache_data = struct cache_data
{
char* dev_strs = nullptr;
size_t* dev_str_addr = nullptr;
size_t* str_length = nullptr;
std::atomic_uint64_t remain_blank_num;
std::atomic_uint64_t record_num;
std::mutex lock;
concurrency::concurrent_unordered_set<size_t> record_hash;
};

/**
* \brief 缓存struct
* str_data:数据struct
* dev_output:字符串匹配后输出位置,下标与cache_data中一一对应,dev_output中数据为1代表匹配成功
* is_cache_valid:数据是否有效
* is_match_done:是否匹配全部完成
* is_output_done:是否已经存入容器 0 代表没有开始 1 代表正在收集 2代表完成
*/
using list_cache = struct cache_struct
{
cache_data str_data;
char* dev_output = nullptr;
size_t dev_output_bytes = 0;
bool is_cache_valid = false;
std::atomic_bool is_match_done;
std::atomic_int is_output_done;
unsigned matched_number = 0;
};

缓存的添加首先会调用record_supplier方法,获取每一个记录,将他们全部读入std::vector中,然后将缓存保存进入显存。

CUDA和OpenCL保存的实现方式略有不同,由于OpenCL中不允许使用size_t,因此无法实现在GPU核函数中对显存寻址。所以OpenCL中每一个字符串的都是定长,长度被定义在constant.h中的MAX_PATH_LENGTH宏中。

CUDA版本的保存方式会更加省显存空间。保存方式如下:
在将所有记录读出后,将会记录总共需要的字节数,然后分配三块内存(即cache_data中的dev_strs,dev_str_addr,str_length)
其中两块在显存中,一块为保存字符串的空间,另一块保存每一个字符串在第一块显存中的偏移量。第三块内存为主机内存,保存第一块显存中每个字符串的长度。

这样保存相比OpenCL的定长字符串省下了很多空间,但是由于需要在GPU核函数中进行两次寻址,因此效率相对比较低。
最后初始化标志位,缓存初始化完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/*
* Class: file_engine_dllInterface_gpu_CudaAccelerator
* Method: initCache
* Signature: (Ljava/lang/String;Ljava/util/function/Supplier;)V
*/
JNIEXPORT void JNICALL Java_file_engine_dllInterface_gpu_CudaAccelerator_initCache
(JNIEnv* env, jobject, jstring key_jstring, jobject record_supplier)
{
const jclass supplier_class = env->GetObjectClass(record_supplier);
const jmethodID get_function = env->GetMethodID(supplier_class, "get", "()Ljava/lang/Object;");
std::vector<std::string> records_vec;
unsigned record_count = 0;
size_t total_bytes = 0;
while (true)
{
const jobject record_from_supplier = env->CallObjectMethod(record_supplier, get_function);
if (record_from_supplier == nullptr)
{
break;
}
const auto jstring_val = reinterpret_cast<jstring>(record_from_supplier);
const auto record = env->GetStringUTFChars(jstring_val, nullptr);
if (const auto record_len = strlen(record); record_len < MAX_PATH_LENGTH)
{
records_vec.emplace_back(record);
total_bytes += record_len;
++total_bytes; // 每个字符串结尾 '\0'
++record_count;
}
env->ReleaseStringUTFChars(jstring_val, record);
env->DeleteLocalRef(record_from_supplier);
}
const auto _key = env->GetStringUTFChars(key_jstring, nullptr);
std::string key(_key);
auto cache = new list_cache;
cache->str_data.record_num = record_count;
cache->str_data.remain_blank_num = MAX_RECORD_ADD_COUNT;

const size_t total_results_size = static_cast<size_t>(record_count) + MAX_RECORD_ADD_COUNT;

const auto alloc_bytes = total_bytes + MAX_RECORD_ADD_COUNT * MAX_PATH_LENGTH;
gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&cache->str_data.dev_strs), alloc_bytes), true,
get_cache_info(key, cache).c_str());
gpuErrchk(cudaMemset(cache->str_data.dev_strs, 0, alloc_bytes), true, nullptr);

gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&cache->str_data.dev_str_addr), total_results_size * sizeof(size_t)),
true, nullptr);
gpuErrchk(cudaMemset(cache->str_data.dev_str_addr, 0, total_results_size * sizeof(size_t)), true, nullptr);

cache->str_data.str_length = new size_t[total_results_size];

gpuErrchk(cudaMalloc(reinterpret_cast<void**>(&cache->dev_output), total_results_size), true,
get_cache_info(key, cache).c_str());
cache->dev_output_bytes = total_results_size;
cache->is_cache_valid = true;
cache->is_match_done = false;
cache->is_output_done = 0;

cudaStream_t stream;
gpuErrchk(cudaStreamCreate(&stream), true, nullptr);
auto target_addr = cache->str_data.dev_strs;
auto save_str_addr_ptr = cache->str_data.dev_str_addr;
unsigned i = 0;
for (const std::string& record : records_vec)
{
const auto record_length = record.length();
gpuErrchk(cudaMemcpyAsync(target_addr, record.c_str(), record_length, cudaMemcpyHostToDevice, stream), true,
nullptr);
const auto str_address = reinterpret_cast<size_t>(target_addr);
//保存字符串在显存上的地址
gpuErrchk(cudaMemcpyAsync(save_str_addr_ptr, &str_address, sizeof(size_t), cudaMemcpyHostToDevice, stream),
true, nullptr);
cache->str_data.str_length[i] = record_length;
target_addr += record_length;
++target_addr;
++save_str_addr_ptr;
++i;
cache->str_data.record_hash.insert(hasher(record));
}
gpuErrchk(cudaStreamSynchronize(stream), true, nullptr);
gpuErrchk(cudaStreamDestroy(stream), true, nullptr);
cache_map.insert(std::make_pair(key, cache));
env->ReleaseStringUTFChars(key_jstring, _key);
env->DeleteLocalRef(supplier_class);
}

字符串匹配及收集

再来谈一下match方法,该方法实现对缓存中数据的字符串匹配,并将匹配后的结果保存进入Java容器中

首先等待清理缓存完成,防止在搜索时缓存被清除导致程序崩溃。然后生成搜索关键字和搜索过滤条件,随后打开多个线程收集GPU匹配结果,最后开启GPU核函数进行匹配,等待所有收集线程退出,即搜索完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*
* Class: file_engine_dllInterface_gpu_CudaAccelerator
* Method: match
* Signature: ([Ljava/lang/String;ZLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;[ZILjava/util/function/BiConsumer;)V
*/
JNIEXPORT void JNICALL Java_file_engine_dllInterface_gpu_CudaAccelerator_match
(JNIEnv* env, jobject, jobjectArray search_case, jboolean is_ignore_case, jstring search_text,
jobjectArray keywords, jobjectArray keywords_lower, jbooleanArray is_keyword_path, jint max_results,
jobject result_collector)
{
if (cache_map.empty())
{
return;
}
wait_for_clear_cache();
std::lock_guard lock_guard(modify_cache_lock);
//生成搜索条件 search_case_vec
std::vector<std::string> search_case_vec;
if (search_case != nullptr)
{
generate_search_case(env, search_case_vec, search_case);
}
//生成搜索关键字 keywords_vec keywords_lower_vec is_keyword_path_ptr
std::vector<std::string> keywords_vec;
std::vector<std::string> keywords_lower_vec;
const auto keywords_length = env->GetArrayLength(keywords);
if (keywords_length > MAX_KEYWORDS_NUMBER)
{
fprintf(stderr, "too many keywords.\n");
return;
}
bool is_keyword_path_ptr[MAX_KEYWORDS_NUMBER]{ false };
const auto is_keyword_path_ptr_bool_array = env->GetBooleanArrayElements(is_keyword_path, nullptr);
for (jsize i = 0; i < keywords_length; ++i)
{
auto tmp_keywords_str = reinterpret_cast<jstring>(env->GetObjectArrayElement(keywords, i));
auto keywords_chars = env->GetStringUTFChars(tmp_keywords_str, nullptr);
#ifdef DEBUG_OUTPUT
std::cout << "keywords: " << keywords_chars << std::endl;
#endif
keywords_vec.emplace_back(keywords_chars);
env->ReleaseStringUTFChars(tmp_keywords_str, keywords_chars);
env->DeleteLocalRef(tmp_keywords_str);

tmp_keywords_str = reinterpret_cast<jstring>(env->GetObjectArrayElement(keywords_lower, i));
keywords_chars = env->GetStringUTFChars(tmp_keywords_str, nullptr);
keywords_lower_vec.emplace_back(keywords_chars);
env->ReleaseStringUTFChars(tmp_keywords_str, keywords_chars);
env->DeleteLocalRef(tmp_keywords_str);

#ifdef DEBUG_OUTPUT
std::cout << "is keyword path: " << static_cast<bool>(is_keyword_path_ptr_bool_array[i]) << std::endl;
#endif
is_keyword_path_ptr[i] = is_keyword_path_ptr_bool_array[i];
}
env->ReleaseBooleanArrayElements(is_keyword_path, is_keyword_path_ptr_bool_array, JNI_ABORT);
//复制全字匹配字符串 search_text
const auto search_text_chars = env->GetStringUTFChars(search_text, nullptr);
std::atomic_uint result_counter = 0;
std::vector<std::thread> collect_threads_vec;
collect_threads_vec.reserve(COLLECT_RESULTS_THREADS);
for (int i = 0; i < COLLECT_RESULTS_THREADS; ++i)
{
collect_threads_vec.emplace_back([&]
{
JNIEnv* thread_env = nullptr;
JavaVMAttachArgs args{ JNI_VERSION_10, nullptr, nullptr };
if (jvm->AttachCurrentThread(reinterpret_cast<void**>(&thread_env), &args) != JNI_OK)
{
fprintf(stderr, "get thread JNIEnv ptr failed");
return;
}
collect_results(thread_env, result_collector, result_counter, max_results, search_case_vec);
jvm->DetachCurrentThread();
});
}
//GPU并行计算
start_kernel(cache_map, search_case_vec, is_ignore_case, search_text_chars,
keywords_vec, keywords_lower_vec, is_keyword_path_ptr);
collect_results(env, result_collector, result_counter, max_results, search_case_vec);
for (auto&& each_thread : collect_threads_vec)
{
if (each_thread.joinable())
{
each_thread.join();
}
}
for (auto& [_, cache_val] : cache_map)
{
if (cache_val->is_output_done.load() != 2)
{
cache_val->is_output_done = 2;
}
}
env->ReleaseStringUTFChars(search_text, search_text_chars);
}

GPU核函数

首先通过核函数线程id获取对应的字符串,然后对字符串进行匹配,如果匹配完成将output设置为1(即缓存结构体中的dev_output),如果匹配失败则设置为0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/**
* TODO 核函数并未对参数做检查,所以如果数据库中包含不是文件路径的记录将会导致崩溃。
* 如 D: C: 这样的记录将会导致核函数失败
*/
__global__ void check(const size_t* str_address_records,
const size_t* total_num,
const int* search_case,
const bool* is_ignore_case,
char* search_text,
char* keywords,
char* keywords_lower_case,
const size_t* keywords_length,
const bool* is_keyword_path,
char* output,
const bool* is_stop_collect_var)
{
const size_t thread_id = GET_TID();
if (thread_id >= *total_num)
{
return;
}
if (*is_stop_collect_var)
{
output[thread_id] = 0;
return;
}
const auto path = reinterpret_cast<const char*>(str_address_records[thread_id]);
if (path == nullptr || !path[0])
{
output[thread_id] = 0;
return;
}
if (not_matched(path, *is_ignore_case, keywords, keywords_lower_case, static_cast<int>(*keywords_length),
is_keyword_path))
{
output[thread_id] = 0;
return;
}
if (*search_case == 0)
{
output[thread_id] = 1;
return;
}
if (*search_case & 4)
{
// 全字匹配
strlwr_cuda(search_text);
char file_name[MAX_PATH_LENGTH]{ 0 };
get_file_name(path, file_name);
strlwr_cuda(file_name);
if (strcmp_cuda(search_text, file_name) != 0)
{
output[thread_id] = 0;
return;
}
}
output[thread_id] = 1;
}

在not_matched()方法中,将会对字符串进行匹配,并尝试进行拼音匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
__device__ bool not_matched(const char* path,
const bool is_ignore_case,
char* keywords,
char* keywords_lower_case,
const int keywords_length,
const bool* is_keyword_path)
{
for (int i = 0; i < keywords_length; ++i)
{
const bool is_keyword_path_val = is_keyword_path[i];
char match_str[MAX_PATH_LENGTH]{ 0 };
if (is_keyword_path_val)
{
get_parent_path(path, match_str);
}
else
{
get_file_name(path, match_str);
}
char* each_keyword;
if (is_ignore_case)
{
each_keyword = keywords_lower_case + i * static_cast<unsigned long long>(MAX_PATH_LENGTH);
strlwr_cuda(match_str);
}
else
{
each_keyword = keywords + i * static_cast<unsigned long long>(MAX_PATH_LENGTH);
}
if (!each_keyword[0])
{
continue;
}
if (strstr_cuda(match_str, each_keyword) == nullptr)
{
if (is_keyword_path_val || !is_str_contains_chinese(match_str))
{
return true;
}
char gbk_buffer[MAX_PATH_LENGTH * 2]{ 0 };
char* gbk_buffer_ptr = gbk_buffer;
// utf-8编码转换gbk
utf8_to_gbk(match_str, static_cast<unsigned>(strlen_cuda(match_str)), &gbk_buffer_ptr, nullptr);
char converted_pinyin[MAX_PATH_LENGTH * 6]{ 0 };
char converted_pinyin_initials[MAX_PATH_LENGTH]{ 0 };
convert_to_pinyin(gbk_buffer, converted_pinyin, converted_pinyin_initials);
if (strstr_cuda(converted_pinyin, each_keyword) == nullptr &&
strstr_cuda(converted_pinyin_initials, each_keyword) == nullptr)
{
return true;
}
}
}
return false;
}

结果收集

每个缓存中拥有一个is_output_done字段,初始为0,每个线程将会通过cas尝试将缓存的is_output_done设置为1,表示正在收集,当收集完成is_output_done会被设置为2表示收集完成。

当线程抢到某个缓存的收集权后将会开始进行收集。
GPU核函数匹配成功后将会把dev_output中对应的标志设置为1,通过dev_output可以拿到对应的字符串地址,然后取出字符串。

由于GPU核函数无法访问操作系统函数,因此在这里还需要过滤是否为文件夹,或是否为文件,过滤完成后将会调用_collect_func()调用match()方法的回调方法,将结果存入Java容器中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
void collect_results(JNIEnv* thread_env, jobject result_collector, std::atomic_uint& result_counter,
const unsigned max_results, const std::vector<std::string>& search_case_vec)
{
const jclass biconsumer_class = thread_env->GetObjectClass(result_collector);
const jmethodID collector = thread_env->GetMethodID(biconsumer_class, "accept",
"(Ljava/lang/Object;Ljava/lang/Object;)V");
bool all_complete;
const auto stop_func = [&]
{
return is_stop() || result_counter.load() >= max_results;
};
auto _collect_func = [&](const std::string& _key, char _matched_record_str[MAX_PATH_LENGTH],
unsigned* matched_number)
{
if (++result_counter >= max_results)
{
is_results_number_exceed = true;
}
auto record_jstring = thread_env->NewStringUTF(_matched_record_str);
auto key_jstring = thread_env->NewStringUTF(_key.c_str());
thread_env->CallVoidMethod(result_collector, collector, key_jstring, record_jstring);
thread_env->DeleteLocalRef(record_jstring);
thread_env->DeleteLocalRef(key_jstring);
++* matched_number;
};
do
{
//尝试退出
all_complete = true;
for (const auto& [key, val] : cache_map)
{
if (stop_func())
{
break;
}
if (!val->is_cache_valid)
{
continue;
}
if (!val->is_match_done.load())
{
//发现仍然有结果未计算完,设置退出标志为false,跳到下一个计算结果
all_complete = false;
continue;
}
if (int expected = 0; !val->is_output_done.compare_exchange_strong(expected, 1))
{
continue;
}
unsigned matched_number = 0;
//复制结果数组到host,dev_output下标对应dev_cache中的下标,若dev_output[i]中的值为1,则对应dev_cache[i]字符串匹配成功
const auto output_ptr = new char[val->dev_output_bytes];
//将dev_output拷贝到output_ptr
gpuErrchk(cudaMemcpy(output_ptr, val->dev_output, val->str_data.record_num, cudaMemcpyDeviceToHost), false,
"collect results failed");
for (size_t i = 0; i < val->str_data.record_num.load(); ++i)
{
if (stop_func())
{
break;
}
//dev_cache[i]字符串匹配成功
if (static_cast<bool>(output_ptr[i]))
{
char matched_record_str[MAX_PATH_LENGTH]{ 0 };
char* str_address;
gpuErrchk(
cudaMemcpy(&str_address, val->str_data.dev_str_addr + i, sizeof(size_t), cudaMemcpyDeviceToHost
), false, nullptr);
//拷贝GPU中的字符串到host
gpuErrchk(cudaMemcpy(matched_record_str, str_address, val->str_data.str_length[i],
cudaMemcpyDeviceToHost), false, "collect results failed");
// 判断文件和文件夹
if (search_case_vec.empty())
{
if (is_file_exist(matched_record_str))
{
_collect_func(key, matched_record_str, &matched_number);
}
}
else
{
if (std::find(search_case_vec.begin(), search_case_vec.end(), "f") != search_case_vec.end())
{
if (is_dir_or_file(matched_record_str) == 1)
{
_collect_func(key, matched_record_str, &matched_number);
}
}
else if (std::find(search_case_vec.begin(), search_case_vec.end(), "d") != search_case_vec.
end())
{
if (is_dir_or_file(matched_record_str) == 0)
{
_collect_func(key, matched_record_str, &matched_number);
}
}
else
{
if (is_file_exist(matched_record_str))
{
_collect_func(key, matched_record_str, &matched_number);
}
}
}
}
}
val->matched_number = matched_number;
val->is_output_done = 2;
delete[] output_ptr;
}
} while (!all_complete && !stop_func());
thread_env->DeleteLocalRef(biconsumer_class);
}

至此,GPU加速的核心方法基本介绍完成。OpenCL的版本除了字符串存储方式不同,其他方法几乎相同,只是使用了不同的框架进行实现,因此不再赘述。

        网上找了很多的资料,还是没有很清楚的了解Everything这款软件。所以自己研究了一下,并记录下来。

        Everything是一个搜索文件速度超快的软件,相比Windows自带的搜索功能,Everything可以做到在数十万文件中做到秒搜。

    本文来分析一下Everything背后的原理以及自己实现一个Everything。

Everything的工作原理

    Everything在第一次打开程序时会扫描整个磁盘,并建立一个索引库。需要注意的是,Everything并不是像Windows文件夹遍历那样一个文件一个文件的搜索并记录。而是通过NTFS文件系统的特性,MFT和USN journal。这也是Everything仅支持NTFS文件系统的原因。

    Master File Table (MFT)

    在NTFS文件系统中,有一个特殊的表,称为MFT表。所有文件夹和文件的名称都被存储在该表中,Everything通过遍历这个表的所有内容,实现在不遍历文件系统就能获取当前磁盘中的所有文件的名称和路径。

    USN journal

    NTFS的日志功能。所有对文件系统的修改操作都被记录在了一个journal日志文件中。Everything通过监控这个日志文件实现对文件修改的监控。

    文件查找

    通过字符串匹配算法从之前建立的索引中对字符串进行匹配,并显示文件名称和路径。

关于usn日志的相关信息,个人认为这篇文章写的比较清楚

NTFS USN日志记录读取历险记 - Irix的博客 | Irix’s Blog (atr.pub)

以及官方文档

Change Journals - Win32 apps | Microsoft Learn

Keeping an Eye on Your NTFS Drives: the Windows 2000 Change Journal Explained | Microsoft Learn

自己实现一个Everything

    我自己已经实现了一个完整的类似于Everything的软件,代码已经在GitHub上开源,欢迎各位去点个star呀。

XUANXUQAQ/File-Engine: An app launcher && efficiency tool (github.com)

一、建立数据库索引

    首先,我们需要做的就是对文件的路径进行索引。我这里选用SQLite。不知道Everything是使用了什么数据库才使他们的索引文件做到这么小,至于为什么这么快,推测可能是先将索引放入内存,索引可用后,程序在后台慢慢同步到硬盘的吧。

    搜索的核心方法是通过创建USN日志并进行读取,然后将数据还原成文件路径并写入数据库。

    1. 首先是获取驱动盘的句柄

    通过CreateFile方法打开磁盘,并返回磁盘句柄。

    关于CreateFile的用法详见MSDNCreateFileW function (fileapi.h) - Win32 apps | Microsoft Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
bool volume::get_handle()
{
// 为\\.\C:的形式
CString lp_file_name(_T("\\\\.\\c:"));
lp_file_name.SetAt(4, vol);


hVol = CreateFile(lp_file_name,
GENERIC_READ | GENERIC_WRITE, // 可以为0
FILE_SHARE_READ | FILE_SHARE_WRITE, // 必须包含有FILE_SHARE_WRITE
nullptr,
OPEN_EXISTING, // 必须包含OPEN_EXISTING, CREATE_ALWAYS可能会导致错误
FILE_ATTRIBUTE_READONLY, // FILE_ATTRIBUTE_NORMAL可能会导致错误
nullptr);


if (INVALID_HANDLE_VALUE != hVol)
{
return true;
}
auto&& info = std::wstring(L"create file handle failed. ") + lp_file_name.GetString() +
L"error code: " + std::to_wstring(GetLastError());
fprintf(stderr, "fileSearcherUSN: %ls", info.c_str());
return false;
}

    2. 之后开始创建USN日志

    通过DeviceIoControl方法向磁盘发送FSCTL_CREATE_USN_JOURNAL命令,创建USN日志。

    关于DeviceIoControl的用法,详见MSDNDeviceIoControl function (ioapiset.h) - Win32 apps | Microsoft Docs

DeviceIoControl的各个命令操作,详见MSDNVolume Management Control Codes - Win32 apps | Microsoft Docs

这里创建USN日志使用的是FSCTL_CREATE_USN命令FSCTL_CREATE_USN_JOURNAL - Win32 apps | Microsoft Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
bool volume::create_usn()
{
cujd.MaximumSize = 0; // 0表示使用默认值
cujd.AllocationDelta = 0; // 0表示使用默认值

DWORD br;
if (
DeviceIoControl(hVol, // handle to volume
FSCTL_CREATE_USN_JOURNAL, // dwIoControlCode
&cujd, // input buffer
sizeof(cujd), // size of input buffer
nullptr, // lpOutBuffer
0, // nOutBufferSize
&br, // number of bytes returned
nullptr) // OVERLAPPED structure
)
{
return true;
}
auto&& info = "create usn error. Error code: " + std::to_string(GetLastError());
fprintf(stderr, "fileSearcherUSN: %s\n", info.c_str());
return false;
}

    3. 获取USN日志信息

    还是通过DeviceIoControl方法发送FSCTL_QUERY_USN_JOURNAL命令,获取当前卷USN日志的各项信息,检查USN日志是否获取成功。

FSCTL_QUERY_USN_JOURNAL - Win32 apps | Microsoft Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
bool volume::get_usn_info()
{
DWORD br;
if (
DeviceIoControl(hVol, // handle to volume
FSCTL_QUERY_USN_JOURNAL, // dwIoControlCode
nullptr, // lpInBuffer
0, // nInBufferSize
&ujd, // output buffer
sizeof(ujd), // size of output buffer
&br, // number of bytes returned
nullptr) // OVERLAPPED structure
)
{
return true;
}
auto&& info = "query usn error. Error code: " + std::to_string(GetLastError());
fprintf(stderr, "fileSearcherUSN: %s\n", info.c_str());
return false;
}

    4. 获取 USN Journal 文件的基本信息

    仍然通过DeviceIoControl发送FSCTL_ENUM_USN_DATA遍历USN日志。

    FSCTL_ENUM_USN_DATA - Win32 apps | Microsoft Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
bool volume::get_usn_journal()
{
MFT_ENUM_DATA med;
med.StartFileReferenceNumber = 0;
med.LowUsn = ujd.FirstUsn;
med.HighUsn = ujd.NextUsn;

// 根目录
CString tmp(_T("C:"));
tmp.SetAt(0, vol);

constexpr auto BUF_LEN = sizeof(USN) + 0x100000; // 尽可能地大,提高效率;

CHAR* buffer = new CHAR[BUF_LEN];
DWORD usn_data_size;
pfrn_name pfrn_name;

while (0 != DeviceIoControl(hVol,
FSCTL_ENUM_USN_DATA,
&med,
sizeof med,
buffer,
BUF_LEN,
&usn_data_size,
nullptr))
{
DWORD dw_ret_bytes = usn_data_size - sizeof(USN);
// 找到第一个 USN 记录
auto usn_record = reinterpret_cast<PUSN_RECORD>(buffer + sizeof(USN));

while (dw_ret_bytes > 0)
{
// 获取到的信息
const CString cfile_name(usn_record->FileName, usn_record->FileNameLength / 2);
pfrn_name.filename = cfile_name;
pfrn_name.pfrn = usn_record->ParentFileReferenceNumber;
// frnPfrnNameMap[UsnRecord->FileReferenceNumber] = pfrnName;
frnPfrnNameMap.insert(std::make_pair(usn_record->FileReferenceNumber, pfrn_name));
// 获取下一个记录
const auto record_len = usn_record->RecordLength;
dw_ret_bytes -= record_len;
usn_record = reinterpret_cast<PUSN_RECORD>(reinterpret_cast<PCHAR>(usn_record) + record_len);
}
// 获取下一页数据
med.StartFileReferenceNumber = *reinterpret_cast<DWORDLONG*>(buffer);
}
delete[] buffer;
return true;
}

    这里我将获取到的信息放入了frnPfrnNameMap中

1
2
3
4
5
6
7
typedef struct _pfrn_name
{
DWORDLONG pfrn = 0;
CString filename;
} pfrn_name;

typedef unordered_map<DWORDLONG, pfrn_name> Frn_Pfrn_Name_Map;

    这里不得不说USN日志文件的数据结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
typedef struct {
DWORD RecordLength;
WORD MajorVersion;
WORD MinorVersion;
DWORDLONG FileReferenceNumber;
DWORDLONG ParentFileReferenceNumber;
USN Usn;
LARGE_INTEGER TimeStamp;
DWORD Reason;
DWORD SourceInfo;
DWORD SecurityId;
DWORD FileAttributes;
WORD FileNameLength;
WORD FileNameOffset;
WCHAR FileName[1];
} USN_RECORD_V2, *PUSN_RECORD_V2;

    这里我们需要用到FileReferenceNumber、ParentFileReferenceNumber、FileName。

    USN日志在Windows中也有对应的命令,使用fsutil就可以创建日志

1
2
fsutil usn createJournal C:
fsutil usn enumData 1 0 1 C:

    此时,可以看到USN日志的数据。

i0KMjz.jpeg

    包含文件参照号,父文件参照号,以及文件名。

    也就是上文提到的FileReferenceNumber、ParentFileReferenceNumber、FileName

    USN日志是一个树状结构,通过文件号和父文件号不断向上找,最终可以拼接出一个完整的路径。获得路径后即可存入数据库。

    最后删除USN日志。

1
fsutil usn deleteJournal /D C:

    5. 删除USN日志文件

    因为USN日志默认并不存在,所以在程序创建之后可以将它删除。

    仍然是通过DeviceIoControl方法发送命令。

    FSCTL_DELETE_USN_JOURNAL - Win32 apps | Microsoft Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
bool volume::delete_usn() const
{
DELETE_USN_JOURNAL_DATA dujd;
dujd.UsnJournalID = ujd.UsnJournalID;
dujd.DeleteFlags = USN_DELETE_FLAG_DELETE;
DWORD br;

if (DeviceIoControl(hVol,
FSCTL_DELETE_USN_JOURNAL,
&dujd,
sizeof(dujd),
nullptr,
0,
&br,
nullptr)
)
{
CloseHandle(hVol);
return true;
}
CloseHandle(hVol);
return false;
}

    6. 此时就可以通过遍历frnPfrnNameMap还原出每个文件的路径,并存入数据库

    通过不断向上遍历查找,最后拼接出文件的路径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void volume::get_path(DWORDLONG frn, CString& output_path)
{
const auto end = frnPfrnNameMap.end();
while (true)
{
auto it = frnPfrnNameMap.find(frn);
if (it == end)
{
output_path = L":" + output_path;
return;
}
output_path = _T("\\") + it->second.filename + output_path;
frn = it->second.pfrn;
}
}

此时,我们已经实现了实现Everything的第一步,建立数据库索引。

二、监控文件的变化

监控文件的变化在这里我采用的是ReadDirectoryChangesW方法进行监控。ReadDirectoryChangesW function (winbase.h) - Win32 apps | Microsoft Docs

文件操作类型有四种,分别为 创建、更改、删除、重命名

对应到Windows API里是FILE_ACTION_ADDED、FILE_ACTION_MODIFIED、FILE_ACTION_REMOVED、FILE_ACTION_RENAMED_OLD_NAME。

通过监控这四个操作,我们可以实现监控文件的变化,并更新数据库索引。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/**
* 开始监控文件夹
*/
void monitor_path(const std::string& path)
{
const auto wPath = string2wstring(path);
DirectoryChangesReader dcr(wPath);
auto&& flag = stop_flag.at(path);
while (flag.load())
{
dcr.EnqueueReadDirectoryChanges();
if (const DWORD rv = dcr.WaitForHandles(); rv == WAIT_OBJECT_0)
{
for (auto&& res = dcr.GetDirectoryChangesResultW(); const auto& [action, data] : res)
{
switch (action)
{
case FILE_ACTION_ADDED:
case FILE_ACTION_RENAMED_NEW_NAME:
if (data.find(L"$RECYCLE.BIN") == std::wstring::npos)
{
std::wstring data_with_disk;
data_with_disk.append(wPath).append(data);
add_record(data_with_disk);
}
break;
case FILE_ACTION_REMOVED:
case FILE_ACTION_RENAMED_OLD_NAME:
if (data.find(L"$RECYCLE.BIN") == std::wstring::npos)
{
std::wstring data_with_disk;
data_with_disk.append(wPath).append(data);
delete_record(data_with_disk);
}
break;
case FILE_ACTION_MODIFIED:
default:
break;
}
}
}
Sleep(10);
}
}

DirectoryChangesReader

将ReadDirectoryChanges()方法的调用封装进DirectoryChangesReader对象中,当调用EnqueueReadDirectoryChanges()方法后,文件更改记录将会被写入缓存中。

然后执行GetDirectoryChangesResultW()方法,调用GetOverlappedResult()方法读取出文件更改信息,然后将所有信息保存进一个vector中,再返回。

GetOverlappedResult function (ioapiset.h) - Win32 apps | Microsoft Learn

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#pragma once
#include <Windows.h>
#include <string>
#include <thread>
#include <unordered_set>
#include <mutex>

// A base class for handles with different invalid values.
template <std::uintptr_t hInvalid>
class Handle
{
public:
Handle(const Handle&) = delete;

Handle(Handle&& rhs) noexcept :
hHandle(std::exchange(rhs.hHandle, hInvalid))
{
}

Handle& operator=(const Handle&) = delete;

Handle& operator=(Handle&& rhs) noexcept
{
std::swap(hHandle, rhs.hHandle);
return *this;
}

// converting to a normal HANDLE
operator HANDLE() const { return hHandle; }

protected:
Handle(HANDLE v) : hHandle(v)
{
// throw if we got an invalid handle
if (hHandle == reinterpret_cast<HANDLE>(hInvalid) || hHandle == INVALID_HANDLE_VALUE)
throw std::runtime_error("invalid handle");
}

~Handle()
{
if (hHandle != reinterpret_cast<HANDLE>(hInvalid)) CloseHandle(hHandle);
}

private:
HANDLE hHandle;
};

using InvalidNullptrHandle = Handle<(std::uintptr_t)nullptr>;

// A class for directory handles
class DirectoryHandleW : public InvalidNullptrHandle
{
public:
DirectoryHandleW(const std::wstring& dir) :
Handle(
CreateFileW(
dir.c_str(), FILE_LIST_DIRECTORY,
FILE_SHARE_READ | FILE_SHARE_DELETE | FILE_SHARE_WRITE,
nullptr, OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS |
FILE_FLAG_OVERLAPPED, nullptr)
)
{
}
};

// A class for event handles
class EventHandle : public InvalidNullptrHandle
{
public:
EventHandle() : Handle(CreateEvent(nullptr, true, false, nullptr))
{
}
};

// A stepping function for FILE_NOTIFY_INFORMATION*
bool StepToNextNotifyInformation(FILE_NOTIFY_INFORMATION*& cur)
{
if (cur->NextEntryOffset == 0) return false;
cur = reinterpret_cast<FILE_NOTIFY_INFORMATION*>(
reinterpret_cast<char*>(cur) + cur->NextEntryOffset
);
return true;
}

// A ReadDirectoryChanges support class
template <size_t Handles = 1, size_t BufByteSize = 4096>
class DirectoryChangesReader
{
public:
static_assert(Handles > 0, "There must be room for at least 1 HANDLE");
static_assert(BufByteSize >= sizeof(FILE_NOTIFY_INFORMATION) + MAX_PATH, "BufByteSize too small");
static_assert(BufByteSize % sizeof(DWORD) == 0, "BufByteSize must be a multiple of sizeof(DWORD)");

DirectoryChangesReader(const std::wstring& dirname) :
hDir(dirname),
ovl{},
hEv{},
handles{ hEv },
buffer{ std::make_unique<DWORD[]>(BufByteSize / sizeof(DWORD)) }
{
}

// A function to fill in data to use with ReadDirectoryChangesW
void EnqueueReadDirectoryChanges()
{
ovl = OVERLAPPED{};
ovl.hEvent = hEv;
const BOOL rdc = ReadDirectoryChangesW(
hDir,
buffer.get(),
BufByteSize,
TRUE,
FILE_NOTIFY_CHANGE_FILE_NAME | FILE_NOTIFY_CHANGE_DIR_NAME |
FILE_NOTIFY_CHANGE_ATTRIBUTES | FILE_NOTIFY_CHANGE_SIZE |
FILE_NOTIFY_CHANGE_LAST_WRITE | FILE_NOTIFY_CHANGE_LAST_ACCESS |
FILE_NOTIFY_CHANGE_CREATION | FILE_NOTIFY_CHANGE_SECURITY,
nullptr,
&ovl,
nullptr
);
if (rdc == 0) throw std::runtime_error("EnqueueReadDirectoryChanges failed");
}

// A function to get a vector of <Action>, <Filename> pairs
std::vector<std::pair<DWORD, std::wstring>>
GetDirectoryChangesResultW()
{
std::vector<std::pair<DWORD, std::wstring>> retval;

auto* fni = reinterpret_cast<FILE_NOTIFY_INFORMATION*>(buffer.get());

DWORD ovlBytesReturned;
if (GetOverlappedResult(hDir, &ovl, &ovlBytesReturned, TRUE))
{
do
{
retval.emplace_back(
fni->Action,
std::wstring{
fni->FileName,
fni->FileName + fni->FileNameLength / sizeof(wchar_t)
}
);
} while (StepToNextNotifyInformation(fni));
}
return retval;
}

// wait for the handles in the handles array
DWORD WaitForHandles()
{
constexpr DWORD wait_threshold = 5000;
return ::WaitForMultipleObjects(Handles, handles, false, wait_threshold);
}

// access to the handles array
HANDLE& operator[](size_t idx) { return handles[idx]; }
constexpr size_t handles_count() const { return Handles; }
private:
DirectoryHandleW hDir;
OVERLAPPED ovl;
EventHandle hEv;
HANDLE handles[Handles];
std::unique_ptr<DWORD[]> buffer; // DWORD-aligned
};

至此,实现Everything第二步,监控文件变化完成。

三、实现UI界面

    由于项目是我几年前写的,所以当时用了Java的Swing框架。上文写到的创建数据库索引的工具编译成了exe,由Java调用。而监控文件变化的工具编译成了dll,通过jni进行调用。

    1. 首先,制作一个搜索框

    使用JFrame创建窗口,JTextField放上面作为搜索框,下面放JLabel作为结果的显示就好,这里就不再贴代码。

    2. 监听用户输入,并查询数据库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
private void addTextFieldDocumentListener() {
textField.getDocument().addDocumentListener(new DocumentListener() {
private boolean isSendSignal;

@Override
public void insertUpdate(DocumentEvent e) {
changeFontOnDisplayFailed();
clearAllAndResetAll();
isSendSignal = setRunningMode();
if (isSendSignal) {
startTime = System.currentTimeMillis();
startSignal.set(true);
isSqlNotInitialized.set(true);
}
if (runningMode == Constants.Enums.RunningMode.PLUGIN_MODE && currentUsingPlugin != null) {
currentUsingPlugin.textChanged(getSearchBarText());
currentUsingPlugin.clearResultQueue();
}
}

@Override
public void removeUpdate(DocumentEvent e) {
changeFontOnDisplayFailed();
clearAllAndResetAll();
isSendSignal = setRunningMode();
if (getSearchBarText().isEmpty()) {
lastMousePositionX = 0;
lastMousePositionY = 0;
listResultsNum.set(0);
currentResultCount.set(0);
startTime = System.currentTimeMillis();
startSignal.set(false);
isSqlNotInitialized.set(false);
} else {
if (isSendSignal) {
startTime = System.currentTimeMillis();
startSignal.set(true);
isSqlNotInitialized.set(true);
}
if (runningMode == Constants.Enums.RunningMode.PLUGIN_MODE && currentUsingPlugin != null) {
currentUsingPlugin.textChanged(getSearchBarText());
currentUsingPlugin.clearResultQueue();
}
}
}

@Override
public void changedUpdate(DocumentEvent e) {
startTime = System.currentTimeMillis();
startSignal.set(false);
isSqlNotInitialized.set(false);
}
});
}

    这里采用记录startTime和startSignal的方式,通过线程池来检测搜索信号,并进行数据库的搜索。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/**
* 生成未格式化的sql
* 第一个map中key保存未格式化的sql,value保存表名称,第二个map为搜索结果的暂时存储容器
*
* @return map
*/
private LinkedHashMap<LinkedHashMap<String, String>, ConcurrentSkipListSet<String>> getNonFormattedSqlFromTableQueue() {
if (isDatabaseUpdated.get()) {
isDatabaseUpdated.set(false);
initPriority();
}
LinkedHashMap<LinkedHashMap<String, String>, ConcurrentSkipListSet<String>> sqlColumnMap = new LinkedHashMap<>();
if (priorityMap.isEmpty()) {
return sqlColumnMap;
}
initTableQueueByPriority();
int asciiSum = 0;
if (keywords.get() != null) {
for (String keyword : keywords.get()) {
int ascII = GetAscII.INSTANCE.getAscII(keyword);
asciiSum += Math.max(ascII, 0);
}
}
int asciiGroup = asciiSum / 100;
String firstTableName = "list" + asciiGroup;
if (searchCase.get() != null && Arrays.asList(searchCase.get()).contains("d")) {
LinkedHashMap<String, String> _priorityMap = new LinkedHashMap<>();
String _sql = "SELECT %s FROM " + firstTableName + " WHERE PRIORITY=" + 0;
_priorityMap.put(_sql, firstTableName);
tableQueue.stream().filter(each -> !each.equals(firstTableName)).forEach(each -> {
String sql = "SELECT %s FROM " + each + " WHERE PRIORITY=" + 0;
_priorityMap.put(sql, each);
});
ConcurrentSkipListSet<String> container;
container = new ConcurrentSkipListSet<>();
sqlColumnMap.put(_priorityMap, container);
} else {
for (Pair i : priorityMap) {
LinkedHashMap<String, String> eachPriorityMap = new LinkedHashMap<>();
String _sql = "SELECT %s FROM " + firstTableName + " WHERE PRIORITY=" + i.priority;
eachPriorityMap.put(_sql, firstTableName);
tableQueue.stream().filter(each -> !each.equals(firstTableName)).forEach(each -> {
String sql = "SELECT %s FROM " + each + " WHERE PRIORITY=" + i.priority;
eachPriorityMap.put(sql, each);
});
ConcurrentSkipListSet<String> container;
container = new ConcurrentSkipListSet<>();
sqlColumnMap.put(eachPriorityMap, container);
}
}
tableQueue.clear();
return sqlColumnMap;
}

    这一步生成所有的查询SQL,简单来说就是生成SELECT * FROM [数据表] 命令并执行,这里因为后面是异步多线程查询,所以为每一个SQL语句都分配了一个用来装数据的容器,后续再进行合并,提高效率,%s字符串格式化主要是为后来可能增加的字段留出升级空间。

    3. 字符串匹配

    运行SQL后,文件路径将被查出来,通过文件名匹配。

    这里我添加了路径搜索,以及是否忽略大小写和拼音搜索的功能。如果只是单纯的匹配字符串,直接用string.contains()或者string.indexof()即可。

    关于为什么不使用KMP算法而只是暴力匹配。因为文件路径中并不存在大量重复字符串,并且路径字符串一般都不会很长。此时使用KMP算法由于大量运算反而会起反效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/**
* 判断文件路径是否满足当前匹配结果(该方法由check方法使用),检查文件路径使用check方法。
*
* @param path 文件路径
* @param isIgnoreCase 是否忽略大小谢
* @return true如果匹配成功
* @see #check(String, String[], String, String[]) ;
*/
private static boolean notMatched(String path, boolean isIgnoreCase, String[] keywords) {
String matcherStrFromFilePath;
boolean isPath;
for (String eachKeyword : keywords) {
if (eachKeyword == null || eachKeyword.isEmpty()) {
continue;
}
char firstChar = eachKeyword.charAt(0);
if (firstChar == '/' || firstChar == File.separatorChar) {
//匹配路径
isPath = true;
Matcher matcher = RegexUtil.slash.matcher(eachKeyword);
eachKeyword = matcher.replaceAll(Matcher.quoteReplacement(""));
//获取父路径
matcherStrFromFilePath = FileUtil.getParentPath(path);
} else {
//获取名字
isPath = false;
matcherStrFromFilePath = FileUtil.getFileName(path);
}
//转换大小写
if (isIgnoreCase) {
matcherStrFromFilePath = matcherStrFromFilePath.toLowerCase();
eachKeyword = eachKeyword.toLowerCase();
}
//开始匹配
if (matcherStrFromFilePath.indexOf(eachKeyword) == -1) {
if (isPath) {
return true;
} else {
if (PinyinUtil.isStringContainChinese(matcherStrFromFilePath)) {
if (PinyinUtil.toPinyin(matcherStrFromFilePath, "").indexOf(eachKeyword) == -1) {
return true;
}
} else {
return true;
}
}
}
}
return false;
}

    4. 显示结果

    最后直接将搜索出来的结果显示在JLabel上就好了,这里就不再粘贴代码。

至此,一个最基本的Everything制作完成

项目源代码在Github开源。XUANXUQAQ/File-Engine: An app launcher && efficiency tool (github.com)

同步与线程安全

线程安全可以说是Java里老生常谈的问题了,包括Java的synchronize,Lock,以及他们的一些基本原理,monitor enter,monitor exit,和Java的AQS框架。

volatile被称为轻量级的线程同步工具。下面就来看一下volatile关键字的用法以及到底有什么作用。

单例模式

最常见的volatile的用法当然是单例模式的双重检验锁了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public class Singleton {  
private volatile static Singleton singleton;
private Singleton (){}
public static Singleton getSingleton() {
    if (singleton == null) {
    synchronized (Singleton.class) {
        if (singleton == null) {
        singleton = new Singleton();
        }
    }
    }
    return singleton;
}
}

双重检验锁,顾名思义,需要判断两次null。

第一次判断是null了之后,进入synchronize代码块,随后再次判断,如果仍然是null,那么创建对象,然后赋值给singleton,最后返回。这是一个线程安全的单例模式。

那么为什么要判断两次呢。

因为如果只判断一次就进入同步代码块,线程一进入同步代码块之后就被切换到线程二,线程二此时判断仍然是null,因此会开始等待线程一退出同步代码块。当线程一创建对象退出后之后,线程二进入,仍然会创建一个新的对象,此时就出现了两个对象。

但是为什么需要加上volatile才真正的线程安全呢。

这就是volatile的第一个功能,禁止代码重排序。

禁止代码重排序

代码重排序出现有很多原因,编译器优化可能导致指令被重排,现代CPU的流水线设计,乱序执行也已经是非常普遍的功能。

new操作本质是分为三步的,首先分配内存空间,然后再将对象初始化,最后让变量指向那个内存空间。

在Java字节码中也有所体现。

1
2
3
4
5
public class Test {
public void test() {
var testObj = new Test();
}
}

这一段简单的代码,字节码为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// class version 61.0 (61)
// access flags 0x21
public class Test {

// compiled from: Test.java

// access flags 0x1
public <init>()V
L0
LINENUMBER 1 L0
ALOAD 0
INVOKESPECIAL java/lang/Object.<init> ()V
RETURN
L1
LOCALVARIABLE this LTest; L0 L1 0
MAXSTACK = 1
MAXLOCALS = 1

// access flags 0x1
public test()V
L0
LINENUMBER 3 L0
NEW Test
DUP
INVOKESPECIAL Test.<init> ()V
ASTORE 1
L1
LINENUMBER 4 L1
RETURN
L2
LOCALVARIABLE this LTest; L0 L2 0
LOCALVARIABLE testObj LTest; L1 L2 1
MAXSTACK = 2
MAXLOCALS = 2
}

可以看到,这里首先调用了NEW字节码,获得了对象的一个引用,然后调用DUP字节码,将操作数栈顶复制并入栈,这时操作数栈有两个该对象的引用。然后弹出一个引用,并调用方法初始化实例,也就是构造函数以及代码块进行初始化,最后再调用ASTORE将引用赋值到局部变量testObj。

因此,在这里如果不使用volatile关键字,可能会导致赋值引用和初始化实例被重排序,如果第一个线程创建对象之后没有初始化就发生了线程切换,第二个线程在第一个判断中已经不为null,使用该对象时就可能出现问题。

关于volatile如何实现禁止指令重排,这里就不得不提到Java的happens-before原则。网上已经有了很多的资料,这里贴上一篇自认为讲的很不错的文章。

Java内存访问重排序的研究 - 美团技术团队 (meituan.com)

简单来说,禁止指令重排是通过内存屏障来实现。

在Java中定义了4中内存屏障,分别是

  • LoadLoad屏障:对于这样的语句Load1; LoadLoad; Load2,在Load2及后续读取操作要读取的数据被访问前,保证Load1要读取的数据被读取完毕。

  • StoreStore屏障:对于这样的语句Store1; StoreStore; Store2,在Store2及后续写入操作执行前,保证Store1的写入操作对其它处理器可见。

  • LoadStore屏障:对于这样的语句Load1; LoadStore; Store2,在Store2及后续写入操作被刷出前,保证Load1要读取的数据被读取完毕。

  • StoreLoad屏障:对于这样的语句Store1; StoreLoad; Load2,在Load2及后续所有读取操作执行前,保证Store1的写入对所有处理器可见。它的开销是四种屏障中最大的。在大多数处理器的实现中,这个屏障是个万能屏障,兼具其它三种内存屏障的功能。

jdk/orderAccess_linux_x86.hpp at master · openjdk/jdk (github.com)

1
2
3
4
inline void OrderAccess::loadload()   { compiler_barrier(); }
inline void OrderAccess::storestore() { compiler_barrier(); }
inline void OrderAccess::loadstore() { compiler_barrier(); }
inline void OrderAccess::storeload() { fence(); }

至于为什么storeload是fence()而其他的是compiler_barrier(),就和x86平台的invalidate queue和store buffer有关了,又会牵扯到MESI缓存一致性协议,这里就不展开了。

当上面单例模式加上volatile之后,变量的初始化和引用赋值将会被禁止重排序,这时就不会再发生上文所说的使用到未初始化完成的对象的问题了。

上面的文章中提到了通过Unsafe. putOrderedObject来对volatile进行优化。

重点的就是下面的代码

1
2
3
4
5
    public void create() {
SomeThing temp = new SomeThing();
unsafe.putOrderedObject(this, valueOffset, null); //将value赋null值只是一项无用操作,实际利用的是这条语句的内存屏障
object = temp;
}

当调用unsafe.putOrderedObject之后,上面的new操作和下面的object = temp之间就隔了一个StoreStore内存屏障,这时就不会出现new未初始化完成就赋值的问题了。

还有一种利用局部变量来减轻volatile影响的优化方案,在Spring中的单例模式中有很好地体现。

spring-framework/DefaultSingletonBeanRegistry.java at main · spring-projects/spring-framework (github.com)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@Nullable
protected Object getSingleton(String beanName, boolean allowEarlyReference) {
// Quick check for existing instance without full singleton lock
Object singletonObject = this.singletonObjects.get(beanName);
if (singletonObject == null && isSingletonCurrentlyInCreation(beanName)) {
singletonObject = this.earlySingletonObjects.get(beanName);
if (singletonObject == null && allowEarlyReference) {
synchronized (this.singletonObjects) {
// Consistent creation of early reference within full singleton lock
singletonObject = this.singletonObjects.get(beanName);
if (singletonObject == null) {
singletonObject = this.earlySingletonObjects.get(beanName);
if (singletonObject == null) {
ObjectFactory<?> singletonFactory = this.singletonFactories.get(beanName);
if (singletonFactory != null) {
singletonObject = singletonFactory.getObject();
this.earlySingletonObjects.put(beanName, singletonObject);
this.singletonFactories.remove(beanName);
}
}
}
}
}
}
return singletonObject;
}

这里利用了一个singletonObject来当做局部变量,这时在代码中进行创建的时候就不会因为外部变量的volatile而导致代码中频繁出现内存屏障,提高了性能。

内存可见性

接下来看另外一个例子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import java.util.concurrent.TimeUnit;

public class Test {
private static boolean flag = true;

public static void main(String[] args) {

new Thread(() -> {
System.out.println("进入循环");
while (flag) {

}
System.out.println("退出循环");
}).start();

new Thread(() -> {
try {
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
flag = false;
System.out.println("已经将flag设置为false");
}).start();
}
}

在这段代码中,理论上应该当第二个线程睡眠1秒后,将flag设置为false,第一个线程就应该退出循环。

但是第一个线程却没有退出。

iCEOtZ.jpeg

这里就需要提到Java的内存模型,也就是JMM(Java Memory Model)。

关于JMM,这里也有一篇文章讲的比较好。

Java内存模型(JMM)总结 - 知乎 (zhihu.com)

JMM规定,所有的变量都存储在主内存中,每一个线程都有一个私有的本地内存,线程对变量的操作必须在本地内存中进行。

因此上面的代码中,由于线程一将flag进行了缓存,被JIT优化之后,就只会读取寄存器中的值,因此不会退出循环。

加上volatile之后,每次读取都会从主存中刷新,就不会再出现这样的问题了。

泛型概述

泛型是现代编程语言中的重要特性,简单来说就是不必指定类型,可以写出非特定类型,模板化的代码,提高代码重用率。

泛型应用最广的地方应该就是容器类了。在Java的容器类中大量的使用了泛型。

例如ArrayList

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public class ArrayList<E> extends AbstractList<E>
implements List<E>, RandomAccess, Cloneable, java.io.Serializable
{
@java.io.Serial
private static final long serialVersionUID = 8683452581122892189L;

/**
* Default initial capacity.
*/
private static final int DEFAULT_CAPACITY = 10;

/**
* Shared empty array instance used for empty instances.
*/
private static final Object[] EMPTY_ELEMENTDATA = {};

/**
* Shared empty array instance used for default sized empty instances. We
* distinguish this from EMPTY_ELEMENTDATA to know how much to inflate when
* first element is added.
*/
private static final Object[] DEFAULTCAPACITY_EMPTY_ELEMENTDATA = {};

/**
* The array buffer into which the elements of the ArrayList are stored.
* The capacity of the ArrayList is the length of this array buffer. Any
* empty ArrayList with elementData == DEFAULTCAPACITY_EMPTY_ELEMENTDATA
* will be expanded to DEFAULT_CAPACITY when the first element is added.
*/
transient Object[] elementData; // non-private to simplify nested class access

/**
* The size of the ArrayList (the number of elements it contains).
*
* @serial
*/
private int size;

可以看到ArrayList存放数据本质就是一个Object数组elementData,因此,如果不使用泛型,直接存储Object。比如将String类型放入容器,那么在get函数取出元素时类型为Object,这时候就需要强制转换。

1
2
3
ArrayList list = new ArrayList();
list.add("test");
String str = (String) list.get(0);

如何这个容器中还存在其他类型的元素,那么取出元素时就很容易出现ClassCastException异常。

1
2
list.add(123);
str = (String) list.get(1);

当然,如果只写一个专门存储String或者Integer的ArrayList也可以,但是这样就需要给每一个类型都写单独编写,更别提还有自己写的类。

因此,泛型就出现了,泛型类可以在编译阶段就检查类型,这样就不会导致类型转换的异常。

下面是一个最简单的泛型类

1
2
3
4
5
6
7
8
9
10
11
public class Generic<T>{ 
private T val;

public Generic(T val) {
this.val = val;
}

public T getVal(){
return val;
}
}

泛型擦除

泛型擦除是指Java中的泛型只在编译期有效,在运行期间会被删除。

如下面这段代码

1
2
3
4
5
6
7
8
public class Foo {  
public void test(List<String> stringList){

}
public void test(List<Integer> integerList) {

}
}

这段代码会报错,方法不能重载,原因就是上面两个方法,在编译后被泛型擦除,最后都是

1
public void test(List) {}

因此不能区分两个函数。

泛型类的继承

泛型类的继承关系不是由泛型类型决定的,如List<Integer>和List<Number>,虽然Integer继承自Number,但是List<Integer>和List<Number>并没有继承关系。

要想使两个泛型类具有继承关系,只能使两个泛型类本身之间继承,或实现接口。

如上面的ArrayList就继承了AbstractList<E<以及List<E<接口。

泛型的逆变和协变

先从一个数组说起,Java的数组是协变的。

看下面这段代码

1
2
3
4
5
6
7
public class Test {
public static void main(String[] args) {
Number[] arr = new Integer[2];
arr[0] = 1;
arr[1] = 0.5;
}
}

这段代码在编译器并不会出错,但是一旦运行,将会抛出一个异常

1
2
Exception in thread "main" java.lang.ArrayStoreException: java.lang.Double
at Test.main(Test.java:5)

这是因为Integer是Number的子类型,因此Integer[]也是Number[]的子类型,这样的性质被称为协变,在编译器并没有检查出错误。

但是在运行时,jvm虚拟机发现这个arr其实是一个Integer类型的数组,不是Number类型,所以不能存放进入double类型的数字,因此抛出了一个异常。

泛型的不变性

因此,在吸取了上面的教训之后,泛型被设计为不变,也就是说,List<Integer>并不是List<Number>的子类型。

这样在编译器就可以检查出错误,防止运行期再报错。

但是这样就引入一个新的问题,如何才能实现协变呢。

协变在Java中还是很常用的,比如我只想要一个Fruit集合,里面存放着水果,但我不想管里面到底存放的是哪种水果。

1
2
3
public void consume(List<Fruit> list) {
......
}

这时,泛型的不变性就带来了麻烦,加入我现在有一个List<Apple>,因为List<Fruit>并不是List<Apple>的父类型,参数就传递不进去。

1
2
List<Apple> appleList = new ArrayList<Apple>;
consume(appleList); // 报错

因此,在泛型中如何实现协变就成为了一个问题。

还有一种情况,如果我们希望往List<Object>中放水果,使用一个produce函数将所有List<Apple>或者List<Banana>的元素全部添加到List<Object>,但又希望在produce函数中向容器添加非Fruit的其他元素时进行检查并报错,这时候就需要逆变。

泛型通配符

要实现泛型协变和逆变,这时通配符 ? 就派上用场了。

  • <? extends>实现了泛型的协变
  • <? super>实现了泛型的逆变

在上面的代码中,假如在consume函数中我们想传入参数,就需要把List<Fruit>改为List<? extends Fruit>。这样就不会产生报错了。

List<? extends Fruit>,其中<? extends Fruit>代表的类型为:Fruit及其子类型,此时传入List<Apple>就没有问题了。

但是当List<Apple>协变为List<? extends Fruit>之后,就不能往容器中再放入元素了。

原因在于,当容器协变后,List<? extends Fruit>中的类型不能再被确定为Apple,<? extends Fruit>虽然包含Apple,但是并不特指为Apple。因此,如果放入一个其他的类型,比如Banana,那么在使用上一个List<Apple>进行读取的时候就会出现类型转换错误。

同样的,如果希望往Fruit中放水果,就可以使用<? super Fruit>让List<Object>逆变为List<? super Fruit>,这样在函数中就可以调用add方法。

从上面的例子可以看出,extends确定了泛型的上界,而super确定了泛型的下界。

PECS

究竟什么时候使用extends,什么时候使用super。也就是PECS

PECS: producer-extends, consumer-super.

生产者使用extends,因为协变只可读取,不可写入。消费者使用super,因为super写入可以保证类型检查。

在Collections中的copy函数就很好地诠释了PECS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public static <T> void copy(List<? super T> dest, List<? extends T> src) {
int srcSize = src.size();
if (srcSize > dest.size())
throw new IndexOutOfBoundsException("Source does not fit in dest");

if (srcSize < COPY_THRESHOLD ||
(src instanceof RandomAccess && dest instanceof RandomAccess)) {
for (int i=0; i<srcSize; i++)
dest.set(i, src.get(i));
} else {
ListIterator<? super T> di=dest.listIterator();
ListIterator<? extends T> si=src.listIterator();
for (int i=0; i<srcSize; i++) {
di.next();
di.set(si.next());
}
}
}

在这里,src使用extends进行协变,只可读取,dest使用super进行逆变,保证写入的类型检查。

Java中最常用的锁ReentrantLock就是基于AQS来实现的。网上已经有很多的资料来讲解AQS的原理。因此这里只是写一下个人比较难以理解的点,用以学习。

这里贴一篇非常好的AQS的文章,讲的非常详细,看完了还是可以学到很多东西。

从ReentrantLock的实现看AQS的原理及应用 - 美团技术团队 (meituan.com)

这里就记录下ReentrantLock的非公平锁。

首先是acquire(int)函数,当tryAcquire函数获取锁失败后,将会把线程加入等待队列。

1
2
3
4
public final void acquire(int arg) {
if (!tryAcquire(arg) && acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
selfInterrupt();
}

关键就在于这个acquireQueued函数,addWaiter()函数就是新建一个节点,然后将节点放入等待队列的末尾,上面放的那篇文章已经非常详细,这里就不写了。

接下来看看acquireQueued()函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
* Acquires in exclusive uninterruptible mode for thread already in
* queue. Used by condition wait methods as well as acquire.
*
* @param node the node
* @param arg the acquire argument
* @return {@code true} if interrupted while waiting
*/
@ReservedStackAccess
final boolean acquireQueued(final Node node, int arg) {
boolean failed = true;
try {
boolean interrupted = false;
for (;;) {
final Node p = node.predecessor();
if (p == head && tryAcquire(arg)) {
setHead(node);
p.next = null; // help GC
failed = false;
return interrupted;
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
interrupted = true;
}
} finally {
if (failed)
cancelAcquire(node);
}
}

这里可以看到进入函数之后,会进入自旋,首先找到前驱节点,如果前驱节点已经是head虚节点了,即当前线程可以获得锁,则调用tryAcquire()函数尝试获取锁。获取成功则将当前节点设置为虚节点,然后退出自旋,返回线程是否被中断。

如果前驱节点不是head,那么将会进入shouldParkAfterFailedAcquire()函数,该函数用于判断线程是否应该被阻塞。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/**
* Checks and updates status for a node that failed to acquire.
* Returns true if thread should block. This is the main signal
* control in all acquire loops. Requires that pred == node.prev.
*
* @param pred node's predecessor holding status
* @param node the node
* @return {@code true} if thread should block
*/
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
int ws = pred.waitStatus;
if (ws == Node.SIGNAL)
/*
* This node has already set status asking a release
* to signal it, so it can safely park.
*/
return true;
if (ws > 0) {
/*
* Predecessor was cancelled. Skip over predecessors and
* indicate retry.
*/
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {
/*
* waitStatus must be 0 or PROPAGATE. Indicate that we
* need a signal, but don't park yet. Caller will need to
* retry to make sure it cannot acquire before parking.
*/
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}

进入后先检测前驱节点的状态ws。

下面是waitStatus的常量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/** Marker to indicate a node is waiting in exclusive mode */
static final Node EXCLUSIVE = null;

/** waitStatus value to indicate thread has cancelled */
static final int CANCELLED = 1;
/** waitStatus value to indicate successor's thread needs unparking */
static final int SIGNAL = -1;
/** waitStatus value to indicate thread is waiting on condition */
static final int CONDITION = -2;
/**
* waitStatus value to indicate the next acquireShared should
* unconditionally propagate
*/
static final int PROPAGATE = -3;
  1. 当前驱节点已经是唤醒状态,也就是Node.SIGNAL的情况下,当前线程就可以直接阻塞,返回true。

  2. 如果ws大于0,即前驱节点为取消状态,那么就进入do-while循环,不断寻找前驱节点,直到找到一个不是取消状态的节点。然后将不是取消状态的节点的next直接指向当前节点(也就是直接跳过中间被取消的节点)。

  3. 否则代表其他状态,则通过cas尝试将状态设置为SIGNAL。

当返回true以后,将会调用parkAndCheckInterrupt()函数,这里进入后会调用unsafe的park方法,将线程阻塞。

1
2
3
4
5
6
7
8
9
10
11
12
private final boolean parkAndCheckInterrupt() {
LockSupport.park(this);
return Thread.interrupted();
}

// LockSupport.java
public static void park(Object blocker) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(false, 0L);
setBlocker(t, null);
}

这里将会设置阻塞对象parkBlocker,这是一个Thread类的私有成员。

1
2
3
4
private static void setBlocker(Thread t, Object arg) {
// Even though volatile, hotspot doesn't need a write barrier here.
UNSAFE.putObject(t, parkBlockerOffset, arg);
}

设置完成后就调用UNSAFE.park()阻塞线程。

这里的parkBlocker是用来记录线程是被哪个对象阻塞的,用于线程监控和分析,通过LockSupport.getBlocker()函数就可以获取parkBlocker。

软件杯经过了几个月的竞赛和答辩过后,总算是获得了一个相对比较满意的成绩。

我们组的名字叫P5天下第一,回过头看这个名字实在是有种说不出的感觉(索尼罪大恶极),我的P5都还在PS4里吃灰,结果后面直接上了XGP,还是P5R,只能说《Only On PlayStation For A While》。

回到正题。首先还是来讲讲数据库和后端的方案吧。

下面是整体的架构图设计

911cG.png

Pest-Identification-Provider

数据库其实很简单,就是一个用户表,还有害虫的目科属种的信息表,以及关联表。

91NO1.png

后端采用spring cloud + nacos的方案。分为provider和consumer,然后通过spring gateway网关进行负载均衡。

9iZGD.jpg

YoloV4图像识别模块也通过gunicorn实现了负载均衡。

用户登录后台采用的是jwtToken + Spring security实现权限校验。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package com.rjgc.configs;

/**
* @author zhaoyunjie
* @date 2021-04-09 18:39
*/
@EnableWebSecurity
@Configuration
public class SpringSecurityConfig extends WebSecurityConfigurerAdapter {

@Autowired
private CustomUserDetainsService userDetainsService;

@Autowired
private CsrfProperties csrfProperties;

@Autowired
private CustomCorsFilter customCorsFilter;

@Autowired
private JwtTokenUtils jwtTokenUtils;

@Autowired
@Qualifier("userServiceImpl")
private UserService userService;

@Override
protected void configure(HttpSecurity http) throws Exception {
http.cors().and()
//添加cors响应头
.addFilterAfter(customCorsFilter, HeaderWriterFilter.class)
//添加token登录过滤器
.addFilterBefore(new TokenLoginFilter(authenticationManager(), jwtTokenUtils), UsernamePasswordAuthenticationFilter.class)
//添加token认证过滤器
.addFilterBefore(new TokenAuthenticationFilter(authenticationManager(), jwtTokenUtils, userService), BasicAuthenticationFilter.class).httpBasic()
.and()
//关闭session
.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS)
.and()
.authorizeRequests()
.antMatchers(HttpMethod.OPTIONS).permitAll()
.antMatchers("/login").permitAll()
.antMatchers(HttpMethod.GET, "/species/*").permitAll()
.antMatchers("/imgFake").permitAll()
.anyRequest().hasRole("admin");
if (csrfProperties.getCsrfDisabled()) {
http.csrf().disable();
}
}

@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
auth.userDetailsService(userDetainsService).passwordEncoder(passwordEncoder());
}

@Override
public void configure(WebSecurity web) throws Exception {
web.ignoring()
.antMatchers("/v2/api-docs",
"/swagger-resources/configuration/ui",
"/swagger-resources",
"/swagger-resources/configuration/security",
"/swagger-ui.html",
"/webjars/**");
}

@Bean
public BCryptPasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
}

在spring security中添加了两个过滤器,一个Token登录过滤器,一个Token认证过滤器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package com.rjgc.filters;

/**
* @author zhaoyunjie
* @date 2021-04-15 19:56
*/
public class TokenLoginFilter extends UsernamePasswordAuthenticationFilter {
private final AuthenticationManager authenticationManager;
private final JwtTokenUtils jwtTokenUtils;

public TokenLoginFilter(AuthenticationManager authenticationManager, JwtTokenUtils jwtTokenUtils) {
this.authenticationManager = authenticationManager;
this.jwtTokenUtils = jwtTokenUtils;
this.setPostOnly(false);
this.setRequiresAuthenticationRequestMatcher(new AntPathRequestMatcher("/login", "POST"));
}

@Override
public Authentication attemptAuthentication(HttpServletRequest req, HttpServletResponse res) throws AuthenticationException {
Map<String, String[]> map = req.getParameterMap();
String username = map.get("username")[0];
String password = map.get("password")[0];
User user = new User(username, password, new ArrayList<>());
return authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(user.getUsername(), user.getPassword(), new ArrayList<>()));
}

/**
* 登录成功
*/
@Override
protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication auth) {
User user = (User) auth.getPrincipal();
String token = jwtTokenUtils.generateToken(user.getUsername());
HashMap<String, Object> map = new HashMap<>();
map.put("token", token);
map.put("username", user.getUsername());
ResponseUtils.out(response, ResBody.success(map));
}

/**
* 登录失败
*/
@Override
protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) {
ResponseUtils.out(response, ResBody.error(new BizException(ExceptionsEnum.LOGIN_FAILED)));
}
}

TokenLoginFilter中有一个JwtTokenUtils,当尝试登录时,将会先调用attemptAuthentication函数,对比账号和密码后将会调用successfulAuthentication函数根据用户名生成一个token,然后返回到前端。

密码采用默认的BCryptPasswordEncoder进行加密。

1
2
3
4
@Bean
public BCryptPasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}

数据库采用的是MySQL + mybatis plus。

整体就是采用MVC架构,在controller中定义接口,然后在service中进行处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package com.rjgc.controller;

/**
* @author zhaoyunjie
* @date 2021-04-08 13:18
*/
@RestController
@RequestMapping("/family")
public class FamilyController {

@Qualifier("familyServiceImpl")
@Autowired
private FamilyService familyService;

@Autowired
@Qualifier("familyGenusVoServiceImpl")
private FamilyGenusVoService familyGenusVoService;

......

/**
* 通过id查询科
*
* @param id id
* @return 结果list
*/
@GetMapping
@ApiOperation("根据id查询科")
public ResBody<Map<String, Object>> selectFamiliesById(@RequestParam int id) {
return ResBody.success(familyVoService.selectFamiliesById(id));
}

/**
* 查询所有科
*
* @param pageNum 当前查询起始页
* @param pageSize 需要的页数
* @return 结果list
*/
@GetMapping("all")
@ApiOperation("查询所有科")
public ResBody<Map<String, Object>> selectAllFamilies(@RequestParam int pageNum, @RequestParam int pageSize) {
return ResBody.success(familyVoService.selectAllFamilies(pageNum, pageSize));
}

@GetMapping("name")
@ApiOperation("根据名称查询")
public ResBody<Map<String, Object>> selectFamiliesByName(@RequestParam int pageNum, @RequestParam int pageSize, @RequestParam String name) {
return ResBody.success(familyVoService.selectFamiliesByName(pageNum, pageSize, name));
}

......

在这里定义了一个统一的返回值ResBody,这样在前端处理会更方便,并且不需要处理null值的问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package com.rjgc.exceptions;

import lombok.Data;

import java.util.HashMap;
import java.util.Map;

/**
* @author zhaoyunjie
* @date 2021-04-09 10:42
*/
@Data
public class ResBody<T> {
/**
* 响应代码
*/
private int code;

/**
* 响应消息
*/
private String message;

/**
* 响应结果
*/
private Object result;

private static Map<String, Object> RESULT_HOLDER = new HashMap<>();

static {
RESULT_HOLDER.put("pages", 1);
RESULT_HOLDER.put("data", "");
}

/**
* 成功
*
* @return resBody
*/
public static <T> ResBody<T> success() {
return success(null);
}

/**
* 成功
*
* @param data data
* @return resBody
*/
public static <T> ResBody<T> success(T data) {
ResBody<T> rb = new ResBody<>();
rb.setCode(ExceptionsEnum.SUCCESS.getResCode());
rb.setMessage(ExceptionsEnum.SUCCESS.getResMsg());
rb.setResult(data);
return rb;
}

/**
* 失败
*/
public static <T> ResBody<T> error(BizException errorInfo) {
ResBody<T> rb = new ResBody<>();
rb.setCode(errorInfo.getExp().getResCode());
rb.setMessage(errorInfo.getExp().getResMsg());
rb.setResult(RESULT_HOLDER);
return rb;
}

public static <T> ResBody<T> error(int errorCode, String errorInfo) {
ResBody<T> rb = new ResBody<>();
rb.setCode(errorCode);
rb.setMessage(errorInfo);
rb.setResult(RESULT_HOLDER);
return rb;
}

@Override
public String toString() {
return "ResBody{" +
"code=" + code +
", message='" + message + '\'' +
", result=" + result +
'}';
}
}

项目中还使用了全局异常处理中心来对不同的异常进行处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package com.rjgc.exceptions;

import lombok.extern.slf4j.Slf4j;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import javax.servlet.http.HttpServletRequest;

/**
* @author zhaoyunjie
* @date 2021-04-09 10:38
* <p>
* 全局异常处理
*/
@RestControllerAdvice
@Slf4j
public class GlobalExceptionHandler {

@ExceptionHandler(value = BizException.class)
public ResBody<BizException> bizExceptionHandler(HttpServletRequest req, BizException e) {
log.debug("Message: " + e.getMessage());
log.debug("Cause: " + e.getCause());
return ResBody.error(e);
}

@ExceptionHandler(value = NullPointerException.class)
public ResBody<NullPointerException> nullPointerExceptionHandler(HttpServletRequest req, NullPointerException e) {
e.printStackTrace();
log.debug("Message: " + e.getMessage());
log.debug("Cause: " + e.getCause());
return ResBody.error(50001, "空指针异常");
}

@ExceptionHandler(value = DuplicateKeyException.class)
public ResBody<DuplicateKeyException> duplicateKeyExceptionHandler(HttpServletRequest req, DuplicateKeyException e) {
log.debug("Message: " + e.getMessage());
log.debug("Cause: " + e.getCause());
return ResBody.error(new BizException(ExceptionsEnum.INVALID_ID));
}

@ExceptionHandler(value = Exception.class)
public ResBody<Exception> exceptionHandler(HttpServletRequest req, Exception e) {
e.printStackTrace();
log.error("Message: " + e.getMessage());
log.error("Cause: " + e.getCause());
return ResBody.error(new BizException(ExceptionsEnum.DATABASE_FAILED));
}
}

以上基本上是provider的设计方案。

Pest-Identification-Consumer

接下来在consumer端,其实就是对provider的进一步封装。通过OpenFegin来进行远程调用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@FeignClient(value = "pest-identification-provider", configuration = FeignConfig.class)
public interface FamilyClient {

@GetMapping("/family")
ResBody<Map<String, Object>> selectFamiliesById(@RequestParam int id);

@GetMapping("/family/all")
ResBody<Map<String, Object>> selectAllFamilies(@RequestParam int pageNum, @RequestParam int pageSize);

@GetMapping("/family/name")
ResBody<Map<String, Object>> selectFamiliesByName(@RequestParam int pageNum, @RequestParam int pageSize, @RequestParam String name);

@PostMapping("/family/{orderId}")
ResBody<Integer> insertFamily(@PathVariable int orderId, @RequestBody Family family);

@PutMapping("/family")
ResBody<Integer> updateFamily(@RequestParam int orderId, @RequestBody Family newFamily);

@DeleteMapping("/family")
ResBody<Integer> deleteFamilyById(@RequestParam int id);
}

Pest-Identification-Gateway

网关这里其实Spring已经给出了很方便的解决方案。只需要在配置文件中配好负载均衡,在nacos就可以实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
server:
port: 8081

spring:
application:
name: pest-identification-gateway

cloud:
nacos:
discovery:
server-addr: localhost:8848
config:
enabled: false

gateway:
discovery:
locator:
enabled: true
lower-case-service-id: true
routes:
- id: consumer
uri: lb://pest-identification-consumer
predicates:
- Path=/**
httpclient:
connect-timeout: 3000
response-timeout: 10000

YoloV4网络

这次竞赛有一个很大的难题,也就是图像识别的问题。这里我采用的是YoloV4框架来进行目标检测。由于这次竞赛还需要实现离线识别,因此在安卓客户端使用了YoloV5框架来实现断网情况下的识别。

首先说说YoloV4框架,在这次竞赛中,我使用了flask框架来对外提供识别api。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@app.route('/startPredict', methods=['POST'])
def start_predict():
"""
开始对img/${ip}中的图片进行预测,返回预测结果
:return:
"""
ret = {}
each_statistics: dict
if not pictures:
return resp_utils.error('还未上传文件')
for each in pictures:
info = load_index(each)
if not info:
each_statistics, base64_str = predict.predict_img(each)
ret[os.path.basename(each)] = {'statistics': each_statistics, 'img': base64_str}
if 'error' not in each_statistics:
save_index(each, base64_str, each_statistics)
else:
each_statistics = info['statistics']
base64_str = info['img']
ret[os.path.basename(each)] = {'statistics': each_statistics, 'img': base64_str}
pictures.clear()
return resp_utils.success(ret)

YoloV4框架采用pytorch实现。这里先贴一张YoloV4的架构图。图片来自 睿智的目标检测32——TF2搭建YoloV4目标检测平台(tensorflow2)_Bubbliiiing的博客-CSDN博客

这位Bubbliiiing真的让我学到了很多,他在B站也有关于YoloV4的教程,看了他的视频,我觉得受益匪浅。Bubbliiiing的个人空间_哔哩哔哩_bilibili

9nwxC.png

具体的介绍可以参考这里

YOLOV4 pytorch实现流程 - 知乎 (zhihu.com)

首先是实现CSPdarknet,定义MISH激活函数,以及实现基本的卷积操作,最后再实现一个完整的CSPdarknet。

这里不得不说pytorch实在是太方便了,只要继承nn.Module,然后写forward就可以了。之前TensorFlow折腾了好久,不同的backend,不同的版本居然都不兼容,坑太多了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# MISH激活函数
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()

def forward(self, x):
return x * torch.tanh(F.softplus(x))


# Conv2d + BatchNormalization + Mish
class BasicConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super(BasicConv, self).__init__()

self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size // 2, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.activation = Mish()

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.activation(x)
return x


# CSPdarknet
class Resblock(nn.Module):
def __init__(self, channels, hidden_channels=None):
super(Resblock, self).__init__()

if hidden_channels is None:
hidden_channels = channels

self.block = nn.Sequential(
BasicConv(channels, hidden_channels, 1),
BasicConv(hidden_channels, channels, 3)
)

def forward(self, x):
return x + self.block(x)


# CSPdarknet的结构块
class Resblock_body(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, first):
super(Resblock_body, self).__init__()
# ----------------------------------------------------------------#
# 利用一个步长为2x2的卷积块进行高和宽的压缩
# ----------------------------------------------------------------#
self.downsample_conv = BasicConv(in_channels, out_channels, 3, stride=2)

if first:
# --------------------------------------------------------------------------#
# 然后建立一个大的残差边self.split_conv0、这个大残差边绕过了很多的残差结构
# --------------------------------------------------------------------------#
self.split_conv0 = BasicConv(out_channels, out_channels, 1)

# ----------------------------------------------------------------#
# 主干部分会对num_blocks进行循环,循环内部是残差结构。
# ----------------------------------------------------------------#
self.split_conv1 = BasicConv(out_channels, out_channels, 1)
self.blocks_conv = nn.Sequential(
Resblock(channels=out_channels, hidden_channels=out_channels // 2),
BasicConv(out_channels, out_channels, 1)
)

self.concat_conv = BasicConv(out_channels * 2, out_channels, 1)
else:
# --------------------------------------------------------------------------#
# 然后建立一个大的残差边self.split_conv0、这个大残差边绕过了很多的残差结构
# --------------------------------------------------------------------------#
self.split_conv0 = BasicConv(out_channels, out_channels // 2, 1)

# ----------------------------------------------------------------#
# 主干部分会对num_blocks进行循环,循环内部是残差结构。
# ----------------------------------------------------------------#
self.split_conv1 = BasicConv(out_channels, out_channels // 2, 1)
self.blocks_conv = nn.Sequential(
*[Resblock(out_channels // 2) for _ in range(num_blocks)],
BasicConv(out_channels // 2, out_channels // 2, 1)
)

self.concat_conv = BasicConv(out_channels, out_channels, 1)

def forward(self, x):
x = self.downsample_conv(x)

x0 = self.split_conv0(x)

x1 = self.split_conv1(x)
x1 = self.blocks_conv(x1)

# ------------------------------------#
# 将大残差边再堆叠回来
# ------------------------------------#
x = torch.cat([x1, x0], dim=1)
# ------------------------------------#
# 最后对通道数进行整合
# ------------------------------------#
x = self.concat_conv(x)

return x


# CSPdarknet53
class CSPDarkNet(nn.Module):
def __init__(self, layers):
super(CSPDarkNet, self).__init__()
self.inplanes = 32
# 416,416,3 -> 416,416,32
self.conv1 = BasicConv(3, self.inplanes, kernel_size=3, stride=1)
self.feature_channels = [64, 128, 256, 512, 1024]

self.stages = nn.ModuleList([
# 416,416,32 -> 208,208,64
Resblock_body(self.inplanes, self.feature_channels[0], layers[0], first=True),
# 208,208,64 -> 104,104,128
Resblock_body(self.feature_channels[0], self.feature_channels[1], layers[1], first=False),
# 104,104,128 -> 52,52,256
Resblock_body(self.feature_channels[1], self.feature_channels[2], layers[2], first=False),
# 52,52,256 -> 26,26,512
Resblock_body(self.feature_channels[2], self.feature_channels[3], layers[3], first=False),
# 26,26,512 -> 13,13,1024
Resblock_body(self.feature_channels[3], self.feature_channels[4], layers[4], first=False)
])

self.num_features = 1
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, x):
x = self.conv1(x)

x = self.stages[0](x)
x = self.stages[1](x)
out3 = self.stages[2](x)
out4 = self.stages[3](out3)
out5 = self.stages[4](out4)

return out3, out4, out5


def darknet53(pretrained):
model = CSPDarkNet([1, 2, 8, 8, 4])
if pretrained:
if isinstance(pretrained, str):
model.load_state_dict(torch.load(pretrained))
else:
raise Exception("darknet request a pretrained path. got [{}]".format(pretrained))
return model

然后再实现yolov4的框架,包括上图中的SPP和PANet,最后输出三个特征层Yolo head。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from collections import OrderedDict

import torch
import torch.nn as nn

from nets.CSPdarknet import darknet53


def conv2d(filter_in, filter_out, kernel_size, stride=1):
pad = (kernel_size - 1) // 2 if kernel_size else 0
return nn.Sequential(OrderedDict([
("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
("bn", nn.BatchNorm2d(filter_out)),
("relu", nn.LeakyReLU(0.1)),
]))


# SPP结构
class SpatialPyramidPooling(nn.Module):
def __init__(self, pool_sizes=[5, 9, 13]):
super(SpatialPyramidPooling, self).__init__()

self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size // 2) for pool_size in pool_sizes])

def forward(self, x):
features = [maxpool(x) for maxpool in self.maxpools[::-1]]
features = torch.cat(features + [x], dim=1)

return features


# 卷积 + 上采样
class Upsample(nn.Module):
def __init__(self, in_channels, out_channels):
super(Upsample, self).__init__()

self.upsample = nn.Sequential(
conv2d(in_channels, out_channels, 1),
nn.Upsample(scale_factor=2, mode='nearest')
)

def forward(self, x, ):
x = self.upsample(x)
return x


# 三次卷积块
def make_three_conv(filters_list, in_filters):
m = nn.Sequential(
conv2d(in_filters, filters_list[0], 1),
conv2d(filters_list[0], filters_list[1], 3),
conv2d(filters_list[1], filters_list[0], 1),
)
return m


# 五次卷积块
def make_five_conv(filters_list, in_filters):
m = nn.Sequential(
conv2d(in_filters, filters_list[0], 1),
conv2d(filters_list[0], filters_list[1], 3),
conv2d(filters_list[1], filters_list[0], 1),
conv2d(filters_list[0], filters_list[1], 3),
conv2d(filters_list[1], filters_list[0], 1),
)
return m


# yolov4的输出
def yolo_head(filters_list, in_filters):
m = nn.Sequential(
conv2d(in_filters, filters_list[0], 3),
nn.Conv2d(filters_list[0], filters_list[1], 1),
)
return m


# yolo_body
class YoloBody(nn.Module):
def __init__(self, num_anchors, num_classes):
super(YoloBody, self).__init__()
# ---------------------------------------------------#
# 生成CSPdarknet53的主干模型
# 获得三个有效特征层,他们的shape分别是:
# 52,52,256
# 26,26,512
# 13,13,1024
# ---------------------------------------------------#
self.backbone = darknet53(None)

self.conv1 = make_three_conv([512, 1024], 1024)
self.SPP = SpatialPyramidPooling()
self.conv2 = make_three_conv([512, 1024], 2048)

self.upsample1 = Upsample(512, 256)
self.conv_for_P4 = conv2d(512, 256, 1)
self.make_five_conv1 = make_five_conv([256, 512], 512)

self.upsample2 = Upsample(256, 128)
self.conv_for_P3 = conv2d(256, 128, 1)
self.make_five_conv2 = make_five_conv([128, 256], 256)

# 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20)=75
final_out_filter2 = num_anchors * (5 + num_classes)
self.yolo_head3 = yolo_head([256, final_out_filter2], 128)

self.down_sample1 = conv2d(128, 256, 3, stride=2)
self.make_five_conv3 = make_five_conv([256, 512], 512)

# 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20)=75
final_out_filter1 = num_anchors * (5 + num_classes)
self.yolo_head2 = yolo_head([512, final_out_filter1], 256)

self.down_sample2 = conv2d(256, 512, 3, stride=2)
self.make_five_conv4 = make_five_conv([512, 1024], 1024)

# 3*(5+num_classes)=3*(5+20)=3*(4+1+20)=75
final_out_filter0 = num_anchors * (5 + num_classes)
self.yolo_head1 = yolo_head([1024, final_out_filter0], 512)

def forward(self, x):
# backbone
x2, x1, x0 = self.backbone(x)

# 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,2048
P5 = self.conv1(x0)
P5 = self.SPP(P5)
# 13,13,2048 -> 13,13,512 -> 13,13,1024 -> 13,13,512
P5 = self.conv2(P5)

# 13,13,512 -> 13,13,256 -> 26,26,256
P5_upsample = self.upsample1(P5)
# 26,26,512 -> 26,26,256
P4 = self.conv_for_P4(x1)
# 26,26,256 + 26,26,256 -> 26,26,512
P4 = torch.cat([P4, P5_upsample], axis=1)
# 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
P4 = self.make_five_conv1(P4)

# 26,26,256 -> 26,26,128 -> 52,52,128
P4_upsample = self.upsample2(P4)
# 52,52,256 -> 52,52,128
P3 = self.conv_for_P3(x2)
# 52,52,128 + 52,52,128 -> 52,52,256
P3 = torch.cat([P3, P4_upsample], axis=1)
# 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128
P3 = self.make_five_conv2(P3)

# 52,52,128 -> 26,26,256
P3_downsample = self.down_sample1(P3)
# 26,26,256 + 26,26,256 -> 26,26,512
P4 = torch.cat([P3_downsample, P4], axis=1)
# 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
P4 = self.make_five_conv3(P4)

# 26,26,256 -> 13,13,512
P4_downsample = self.down_sample2(P4)
# 13,13,512 + 13,13,512 -> 13,13,1024
P5 = torch.cat([P4_downsample, P5], axis=1)
# 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512
P5 = self.make_five_conv4(P5)

# ---------------------------------------------------#
# 第三个特征层
# y3=(batch_size,75,52,52)
# ---------------------------------------------------#
out2 = self.yolo_head3(P3)
# ---------------------------------------------------#
# 第二个特征层
# y2=(batch_size,75,26,26)
# ---------------------------------------------------#
out1 = self.yolo_head2(P4)
# ---------------------------------------------------#
# 第一个特征层
# y1=(batch_size,75,13,13)
# ---------------------------------------------------#
out0 = self.yolo_head1(P5)

return out0, out1, out2

搭建完成之后就是训练了。

训练采用的是实验室的TITAN RTX,这里我们小组总共找了将近1900张图片,然后手动打标签,不得不说那一周看虫子都快看吐了。

91pvg.jpg

91e7I.jpg

最后通过脚本扫描所有的图片,将对应的xml标记文件对应。

91rLD.jpg

然后开始训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def start_train(queue=None):
global model, __iterationCount
__iterationCount = 1

# classes和anchor的路径
anchors_path = 'model_data/yolo_anchors.txt'
classes_path = 'model_data/all_classes.txt'
# 获取classes和anchor
class_names = get_classes(classes_path)
anchors = get_anchors(anchors_path)
num_classes = len(class_names)

# 创建yolo模型
model = YoloBody(len(anchors[0]), num_classes)

model_path = "model_data/yolo4_weights.pth"
print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if os.path.exists(model_path):
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')
else:
print("error no pretrained model")

net = model.train()

if Cuda:
net = torch.nn.DataParallel(model)
cudnn.benchmark = True
net = net.cuda()

# 建立loss函数
yolo_losses = []
for i in range(3):
yolo_losses.append(YOLOLoss(np.reshape(anchors, [-1, 2]), num_classes,
(input_shape[1], input_shape[0]), smooth_label, Cuda, normalize))

# 获得图片路径和标签
annotation_path = '2007_train.txt'
val_split = 0.1
with open(annotation_path, encoding='utf-8') as f:
lines = f.readlines()
np.random.seed(10101)
np.random.shuffle(lines)
np.random.seed(None)
num_val = int(len(lines) * val_split)
num_train = len(lines) - num_val

Init_Epoch = 0
Freeze_Epoch = 200
Unfreeze_Epoch = 300

# 冻结训练部分
lr = 1e-3
Batch_size = 4

global optimizer
optimizer = optim.Adam(net.parameters(), lr)
if Cosine_lr:
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

if Use_Data_Loader:
train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
is_train=True)
val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
else:
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)

epoch_size = max(1, num_train // Batch_size)
epoch_size_val = num_val // Batch_size
# 冻结训练
for param in model.backbone.parameters():
param.requires_grad = False

for epoch in range(Init_Epoch, Freeze_Epoch):
fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Freeze_Epoch, Cuda,
queue=queue)
lr_scheduler.step()

# 解冻训练
lr = 1e-4
Batch_size = 2

optimizer = optim.Adam(net.parameters(), lr)
if Cosine_lr:
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

if Use_Data_Loader:
train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
is_train=True)
val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
else:
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)

epoch_size = max(1, num_train // Batch_size)
epoch_size_val = num_val // Batch_size
# 解冻后训练
for param in model.backbone.parameters():
param.requires_grad = True

for epoch in range(Freeze_Epoch, Unfreeze_Epoch):
fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Unfreeze_Epoch, Cuda,
queue=queue)
lr_scheduler.step()

最后训练出来的效果还是很不错的。

91y9b.png

91vPP.png

YoloV5网络

前面说到了由于YoloV4过于庞大,无法放到安卓app中,于是我们小组还开发了YoloV5的版本。

iOyPCc.png

模型也是使用pytorch实现。这里就不再赘述。

训练过程基本和上面差不多,使用相同的训练集,训练完成后获取模型即可。

比较值得记录的是如何放到安卓app上。

Android APP

安卓客户端基础功能采用vue-cli开发,制作为h5小程序,由于pytorch在安卓端必须采用Android studio引入,所以不能直接使用vue打包成apk。

因此我们小组采用capacitorJS将vue app项目转换成原生代码项目,再使用pytorch原生库加载离线识别模型文件,使用NanoHTTPD实现在本地模拟在线api,采用sqlite和okhttpClient定期同步远程数据库。

在安卓客户端拍图片之后,将会先被编码为base64,然后再上传至YoloV4网络,获取结果之后再显示。

在build.gradle中引入pytorch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
dependencies {
implementation fileTree(include: ['*.jar'], dir: 'libs')
implementation "androidx.appcompat:appcompat:$androidxAppCompatVersion"
implementation project(':capacitor-android')
testImplementation "junit:junit:$junitVersion"
androidTestImplementation "androidx.test.ext:junit:$androidxJunitVersion"
androidTestImplementation "androidx.test.espresso:espresso-core:$androidxEspressoCoreVersion"
implementation project(':capacitor-cordova-android-plugins')
implementation'org.nanohttpd:nanohttpd:2.3.1'
implementation'com.alibaba:fastjson:1.1.72.android'
implementation'com.squareup.okhttp3:okhttp:3.10.0'
implementation 'org.pytorch:pytorch_android:1.8.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0'
}

然后把训练好的YoloV5模型放入assets文件夹中,在app打开后进行加载。

1
2
3
4
5
6
7
8
/**
* 加载网络模型
* @param assetManager assets读取工具
* @param moduleName Torchscript文件名
*/
public void init(AssetManager assetManager, String moduleName) {
module = PyTorchAndroid.loadModuleFromAsset(assetManager, moduleName);
}

然后实现predict函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
* 输入图片,输出预测信息
* @param bitmap 输入图片
* @param processedImg 处理后的图片,Object[1] 需要放到数组中,否则局部变量无法获取
* @return 预测信息
*/
public Map<String, Integer> predict(Bitmap bitmap, Object[] processedImg) {
HashMap<String, Integer> resultMap = new HashMap<>();
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
new float[]{0.0f, 0.0f, 0.0f}, new float[]{1.0f, 1.0f, 1.0f});

// running the model
IValue value = IValue.from(inputTensor);
final IValue[] output = module.forward(value).toTuple();
final Tensor outputTensor = output[0].toTensor();
final float[] outputs = outputTensor.getDataAsFloatArray();
float mImgScaleX = bitmap.getWidth() / imageSize;
float mImgScaleY = bitmap.getHeight() / imageSize;
float mIvScaleX = (bitmap.getWidth() > bitmap.getHeight() ? imageSize / bitmap.getWidth() : imageSize / bitmap.getHeight());
float mIvScaleY = (bitmap.getHeight() > bitmap.getWidth() ? imageSize / bitmap.getHeight() : imageSize / bitmap.getWidth());
final ArrayList<Result> results = outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY);
boolean hasRect = false;
for (Result each : results) {
String code = Classes.classes[each.classIndex];
if (resultMap.containsKey(code)) {
Integer num = resultMap.get(code);
num += 1;
resultMap.put(code, num);
} else {
resultMap.put(code, 1);
}
if (each.rect.top != each.rect.bottom && each.rect.left != each.rect.right) {
bitmap = drawRectangles(bitmap, each.rect);
processedImg[0] = bitmap;
hasRect = true;
}
if (!hasRect) {
processedImg[0] = bitmap;
}
}
if (resultMap.isEmpty()) {
resultMap.put("error", 1);
}
System.out.println(resultMap);
return resultMap;
}

以及在识别后的图片上画框的函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
* 在bitmap上画框
* @param imageBitmap 输入图片
* @param valueRects 框的位置
* @return 画了框的图片
*/
private Bitmap drawRectangles(Bitmap imageBitmap, Rect valueRects) {
Bitmap mutableBitmap = imageBitmap.copy(Bitmap.Config.ARGB_8888, true);
Canvas canvas = new Canvas(mutableBitmap);
Paint paint = new Paint();
for (int i = 0; i < 8; i++) {
paint.setColor(Color.RED);
paint.setStyle(Paint.Style.STROKE);//不填充
paint.setStrokeWidth(10); //线的宽度
canvas.drawRect(valueRects.left, valueRects.top, valueRects.right, valueRects.bottom, paint);
}
return mutableBitmap;
}

以及nms函数,和非极大抑制函数,计算IOU的函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/**
* 对原始输出进行处理
* @param outputs 输出
* @param imgScaleX 相对图片大小X
* @param imgScaleY 相对图片大小Y
* @param ivScaleX 显示最小图片大小X
* @param ivScaleY 显示最小图片大小Y
* @return Result list
*/
private static ArrayList<Result> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY) {
ArrayList<Result> results = new ArrayList<>();
int mOutputRow = 25200;
int mOutputColumn = Classes.classes.length + 5;
for (int i = 0; i < mOutputRow; i++) {
try {
if (outputs[i * mOutputColumn + 4] > mThreshold) {
float x = outputs[i * mOutputColumn];
float y = outputs[i * mOutputColumn + 1];
float w = outputs[i * mOutputColumn + 2];
float h = outputs[i * mOutputColumn + 3];

float left = imgScaleX * (x - w / 2);
float top = imgScaleY * (y - h / 2);
float right = imgScaleX * (x + w / 2);
float bottom = imgScaleY * (y + h / 2);

float max = outputs[i * mOutputColumn + 5];
int cls = 0;
for (int j = 0; j < mOutputColumn - 5; j++) {
if (outputs[i * mOutputColumn + 5 + j] > max) {
max = outputs[i * mOutputColumn + 5 + j];
cls = j;
}
}

Rect rect = new Rect((int) ((float) 0 + ivScaleX * left), (int) ((float) 0 + top * ivScaleY), (int) ((float) 0 + ivScaleX * right), (int) ((float) 0 + ivScaleY * bottom));
Result result = new Result(cls, outputs[i * mOutputColumn + 4], rect);
results.add(result);
}
} catch (ArrayIndexOutOfBoundsException e) {
e.printStackTrace();
}
}
return nonMaxSuppression(results);
}

/**
* 非极大抑制
* @param boxes 结果
* @return 筛选后的结果
*/
private static ArrayList<Result> nonMaxSuppression(ArrayList<Result> boxes) {

Collections.sort(boxes,
(o1, o2) -> o1.score.compareTo(o2.score));

ArrayList<Result> selected = new ArrayList<>();
boolean[] active = new boolean[boxes.size()];
Arrays.fill(active, true);
int numActive = active.length;

boolean done = false;
for (int i = 0; i < boxes.size() && !done; i++) {
if (active[i]) {
Result boxA = boxes.get(i);
selected.add(boxA);
if (selected.size() >= CnnApi.mNmsLimit) break;

for (int j = i + 1; j < boxes.size(); j++) {
if (active[j]) {
Result boxB = boxes.get(j);
if (IOU(boxA.rect, boxB.rect) > CnnApi.mThreshold) {
active[j] = false;
numActive -= 1;
if (numActive <= 0) {
done = true;
break;
}
}
}
}
}
}
return selected;
}

private static float IOU(Rect a, Rect b) {
float areaA = (a.right - a.left) * (a.bottom - a.top);
if (areaA <= 0.0) return 0.0f;

float areaB = (b.right - b.left) * (b.bottom - b.top);
if (areaB <= 0.0) return 0.0f;

float intersectionMinX = Math.max(a.left, b.left);
float intersectionMinY = Math.max(a.top, b.top);
float intersectionMaxX = Math.min(a.right, b.right);
float intersectionMaxY = Math.min(a.bottom, b.bottom);
float intersectionArea = Math.max(intersectionMaxY - intersectionMinY, 0) *
Math.max(intersectionMaxX - intersectionMinX, 0);
return intersectionArea / (areaA + areaB - intersectionArea);
}

static class Result {
int classIndex;
Float score;
Rect rect;

public Result(int cls, Float output, Rect rect) {
this.classIndex = cls;
this.score = output;
this.rect = rect;
}
}

然后将这个离线识别方法封装成API。

采用NanoHTTPD对外提供接口。这样在无网络的环境下只需要把服务器的ip地址更改为127.0.0.1,即可请求到本地的YoloV5识别网络,并实现离线识别。

最后放几张图片

91tNK.jpg

914Ea.jpg