Apollo Cyber Study. (cyber/base 3)

終於完成cyber/base的部份了

程式碼中以 `// Study:` 開頭的註解是我自己加的筆記

cyber/base/reentrant_rw_lock

先定義一下 Reentrant RW lock:
大體上等價於一般的 RW lock,
主要的不同是:
 當前持有寫鎖的線程想再次加鎖(讀鎖或寫鎖)時不會被 block(下面的 code 是用 write_thread_id_ 來判斷的),
既避免自己把自己鎖死, 也減少了上下文切換的成本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/

#ifndef CYBER_BASE_REENTRANT_RW_LOCK_H_
#define CYBER_BASE_REENTRANT_RW_LOCK_H_

#include <stdint.h>
#include <unistd.h>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <thread>

#include "cyber/base/rw_lock_guard.h"

namespace apollo {
namespace cyber {
namespace base {

// Study: to be reentrant the lock first has to know which thread owns it.
// A default-constructed std::thread::id is used as the "no owner" marker.
static const std::thread::id NULL_THREAD_ID{};

// Spin-based reader/writer lock whose write side is reentrant.
//
// State is kept in lock_num_:
//   RW_LOCK_FREE (0)      nobody holds the lock
//   > 0                   number of read locks currently held
//   WRITE_EXCLUSIVE (-1)  one thread holds the write lock; each nested
//                         WriteLock() by that thread decrements further
//
// The locking functions are private: they are only reachable through the
// befriended ReadLockGuard / WriteLockGuard.
class ReentrantRWLock {
  friend class ReadLockGuard<ReentrantRWLock>;
  friend class WriteLockGuard<ReentrantRWLock>;

 public:
  static const int32_t RW_LOCK_FREE = 0;
  // Study: lock_num_ == WRITE_EXCLUSIVE means a writer holds the lock, so
  // other threads can neither read nor write until it is released.
  static const int32_t WRITE_EXCLUSIVE = -1;
  // Number of failed spins after which a waiting thread yields the CPU.
  static const uint32_t MAX_RETRY_TIMES = 5;
  static const std::thread::id null_thread;

  ReentrantRWLock() {}
  // write_first == true makes readers wait while any writer is queued.
  explicit ReentrantRWLock(bool write_first) : write_first_(write_first) {}

 private:
  // Non-copyable.
  ReentrantRWLock(const ReentrantRWLock&) = delete;
  ReentrantRWLock& operator=(const ReentrantRWLock&) = delete;

  // all these function only can used by ReadLockGuard/WriteLockGuard;
  void ReadLock();
  void WriteLock();
  void ReadUnlock();
  void WriteUnlock();

  // Study: records which thread currently holds the write lock.
  std::thread::id write_thread_id_{NULL_THREAD_ID};
  // Number of threads currently spinning inside WriteLock().
  std::atomic<uint32_t> write_lock_wait_num_{0};
  // Study: repeated locking is allowed, so the locks have to be counted.
  std::atomic<int32_t> lock_num_{0};
  bool write_first_ = true;
};

inline void ReentrantRWLock::ReadLock() {
// Study: Reentrant Check
if (write_thread_id_ == std::this_thread::get_id()) {
return;
}

uint32_t retry_times = 0;
int32_t lock_num = lock_num_.load(std::memory_order_acquire);
if (write_first_) {
do {
while (lock_num < RW_LOCK_FREE ||
write_lock_wait_num_.load(std::memory_order_acquire) > 0) {
if (++retry_times == MAX_RETRY_TIMES) {
// saving cpu
std::this_thread::yield();
retry_times = 0;
}
lock_num = lock_num_.load(std::memory_order_acquire);
}
} while (!lock_num_.compare_exchange_weak(lock_num, lock_num + 1,
std::memory_order_acq_rel,
std::memory_order_relaxed));
} else {
do {
while (lock_num < RW_LOCK_FREE) {
if (++retry_times == MAX_RETRY_TIMES) {
// saving cpu
std::this_thread::yield();
retry_times = 0;
}
lock_num = lock_num_.load(std::memory_order_acquire);
}
} while (!lock_num_.compare_exchange_weak(lock_num, lock_num + 1,
std::memory_order_acq_rel,
std::memory_order_relaxed));
}
}

inline void ReentrantRWLock::WriteLock() {
auto this_thread_id = std::this_thread::get_id();
// Study: Reentrant Check
if (write_thread_id_ == this_thread_id) {
lock_num_.fetch_sub(1);
return;
}
int32_t rw_lock_free = RW_LOCK_FREE;
uint32_t retry_times = 0;
write_lock_wait_num_.fetch_add(1);
while (!lock_num_.compare_exchange_weak(rw_lock_free, WRITE_EXCLUSIVE,
std::memory_order_acq_rel,
std::memory_order_relaxed)) {
// rw_lock_free will change after CAS fail, so init agin
rw_lock_free = RW_LOCK_FREE;
if (++retry_times == MAX_RETRY_TIMES) {
// saving cpu
std::this_thread::yield();
retry_times = 0;
}
}
write_thread_id_ = this_thread_id;
write_lock_wait_num_.fetch_sub(1);
}

// Study: compared with the non-reentrant version there is one extra thread
// check.
//
// Releases a read lock. The thread that owns the write lock never counted its
// read lock in ReadLock(), so for that thread this is a no-op.
inline void ReentrantRWLock::ReadUnlock() {
  if (std::this_thread::get_id() != write_thread_id_) {
    lock_num_.fetch_sub(1);
  }
}

// Study: updating lock_num_ has to go hand in hand with maintaining
// write_thread_id_.
//
// Releases one level of the write lock held by the calling thread. When this
// is the outermost level (lock_num_ == WRITE_EXCLUSIVE) the owner id is
// cleared as well.
//
// The owner id is cleared BEFORE lock_num_ is incremented. Incrementing first
// would publish the lock as free while write_thread_id_ still awaited its
// reset: another thread could then win the CAS in WriteLock(), store its own
// id, and have it overwritten with NULL_THREAD_ID here, which breaks
// reentrancy for the new owner.
inline void ReentrantRWLock::WriteUnlock() {
  // While the write lock is held, ReadLock()/WriteLock() in other threads only
  // CAS from a non-negative value, so with guard-based usage the owner is the
  // only thread changing lock_num_ and this check cannot go stale.
  if (lock_num_.load(std::memory_order_acquire) == WRITE_EXCLUSIVE) {
    write_thread_id_ = NULL_THREAD_ID;
  }
  lock_num_.fetch_add(1);
}

} // namespace base
} // namespace cyber
} // namespace apollo

#endif // CYBER_BASE_REENTRANT_RW_LOCK_H_

cyber/base/signal

signal 和 slot 的概念可以參考 Qt 的 signals & slots 機制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/

#ifndef CYBER_BASE_SIGNAL_H_
#define CYBER_BASE_SIGNAL_H_

#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <mutex>

namespace apollo {
namespace cyber {
namespace base {

// Study: forward declarations - Signal refers to Slot and Connection, which
// are defined further down in this header.
template <typename... Args>
class Slot;

template <typename... Args>
class Connection;

// A signal owns a list of slots (callbacks taking Args...). Invoking the
// signal calls every slot; Connect() registers a new callback and returns a
// Connection that can later be used to disconnect it. The slot list is
// guarded by mutex_.
template <typename... Args>
class Signal {
 public:
  using Callback = std::function<void(Args...)>;
  using SlotPtr = std::shared_ptr<Slot<Args...>>;
  using SlotList = std::list<SlotPtr>;
  using ConnectionType = Connection<Args...>;

  Signal() {}
  virtual ~Signal() { DisconnectAllSlots(); }

  // Study: called when the signal is emitted.
  void operator()(Args... args) {
    // Study: keep the locked region small - copy the slot list under the
    // lock, then run the callbacks without holding it.
    SlotList snapshot;
    {
      std::lock_guard<std::mutex> guard(mutex_);
      snapshot = slots_;
    }

    for (const SlotPtr& target : snapshot) {
      (*target)(args...);
    }

    ClearDisconnectedSlots();
  }

  // Study: this is the first step when using a signal - connect it to a slot.
  // Wraps cb in a new Slot, stores it and returns the matching Connection.
  ConnectionType Connect(const Callback& cb) {
    SlotPtr created = std::make_shared<Slot<Args...>>(cb);
    {
      std::lock_guard<std::mutex> guard(mutex_);
      slots_.push_back(created);
    }
    return ConnectionType(created, this);
  }

  // Marks every stored slot that belongs to conn as disconnected and removes
  // it from the list. Returns true when at least one slot matched.
  bool Disconnect(const ConnectionType& conn) {
    bool matched = false;
    {
      std::lock_guard<std::mutex> guard(mutex_);
      for (const SlotPtr& candidate : slots_) {
        if (!conn.HasSlot(candidate)) {
          continue;
        }
        candidate->Disconnect();
        matched = true;
      }
    }

    if (matched) {
      ClearDisconnectedSlots();
    }
    return matched;
  }

  // Disconnects every slot and empties the list.
  void DisconnectAllSlots() {
    std::lock_guard<std::mutex> guard(mutex_);
    for (const SlotPtr& doomed : slots_) {
      doomed->Disconnect();
    }
    slots_.clear();
  }

 private:
  Signal(const Signal&) = delete;
  Signal& operator=(const Signal&) = delete;

  // Drops all slots whose connected() flag is false.
  void ClearDisconnectedSlots() {
    std::lock_guard<std::mutex> guard(mutex_);
    slots_.remove_if(
        [](const SlotPtr& entry) { return !entry->connected(); });
  }

  SlotList slots_;
  std::mutex mutex_;
};

// Study: This represent the connection status between signal and slot
// Help maintain the real time connectivity
template <typename... Args>
class Connection {
public:
using SlotPtr = std::shared_ptr<Slot<Args...>>;
using SignalPtr = Signal<Args...>*;

Connection() : slot_(nullptr), signal_(nullptr) {}
Connection(const SlotPtr& slot, const SignalPtr& signal)
: slot_(slot), signal_(signal) {}
virtual ~Connection() {
slot_ = nullptr;
signal_ = nullptr;
}

Connection& operator=(const Connection& another) {
if (this != &another) {
this->slot_ = another.slot_;
this->signal_ = another.signal_;
}
return *this;
}

bool HasSlot(const SlotPtr& slot) const {
if (slot != nullptr && slot_ != nullptr) {
return slot_.get() == slot.get();
}
return false;
}

bool IsConnected() const {
if (slot_) {
return slot_->connected();
}
return false;
}

bool Disconnect() {
if (signal_ && slot_) {
return signal_->Disconnect(*this);
}
return false;
}

private:
SlotPtr slot_;
SignalPtr signal_;
};

// A slot wraps one callback together with a "connected" flag. Invoking the
// slot runs the callback only while the flag is set and the callback is
// non-empty.
template <typename... Args>
class Slot {
 public:
  using Callback = std::function<void(Args...)>;

  explicit Slot(const Callback& cb, bool connected = true)
      : cb_(cb), connected_(connected) {}
  Slot(const Slot& another)
      : cb_(another.cb_), connected_(another.connected_) {}
  virtual ~Slot() {}

  // Study: when the slot receives the signal, run the callback.
  void operator()(Args... args) {
    if (!connected_ || !cb_) {
      return;
    }
    cb_(args...);
  }

  // After this call the slot ignores further invocations.
  void Disconnect() { connected_ = false; }
  bool connected() const { return connected_; }

 private:
  Callback cb_;
  bool connected_ = true;
};

} // namespace base
} // namespace cyber
} // namespace apollo

#endif // CYBER_BASE_SIGNAL_H_

cyber/base/thread_pool

建立在 bounded_queue 上的一個 thread pool。
Task 是一個 function 加上它對應的 arguments

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/

#ifndef CYBER_BASE_THREAD_POOL_H_
#define CYBER_BASE_THREAD_POOL_H_

#include <atomic>
#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

#include "cyber/base/bounded_queue.h"

namespace apollo {
namespace cyber {
namespace base {

// Fixed set of worker threads that execute tasks taken from a BoundedQueue.
// Tasks are stored type-erased as std::function<void()>.
class ThreadPool {
 public:
  // thread_num: number of worker threads to start.
  // max_task_num: size handed to the task queue's Init().
  explicit ThreadPool(std::size_t thread_num, std::size_t max_task_num = 1000);

  // Stops the pool and joins all worker threads.
  ~ThreadPool();

  // Study: F is the function and Args are its arguments; the result is
  // handed back wrapped in a std::future.
  template <typename F, typename... Args>
  auto Enqueue(F&& f, Args&&... args)
      -> std::future<typename std::result_of<F(Args...)>::type>;

 private:
  std::vector<std::thread> workers_;
  BoundedQueue<std::function<void()>> task_queue_;
  // Set once the pool is shutting down; workers and Enqueue() check it.
  std::atomic<bool> stop_;
};

// Initialises the task queue with room for max_task_num entries and a
// BlockWaitStrategy, then starts `threads` workers. Throws
// std::runtime_error when the queue cannot be initialised.
inline ThreadPool::ThreadPool(std::size_t threads, std::size_t max_task_num)
    : stop_(false) {
  if (!task_queue_.Init(max_task_num, new BlockWaitStrategy())) {
    throw std::runtime_error("Task queue init failed.");
  }

  // Each worker keeps pulling tasks until stop_ is set; a dequeue that
  // reports failure simply re-checks the flag.
  const auto worker_loop = [this] {
    while (!stop_) {
      std::function<void()> job;
      if (task_queue_.WaitDequeue(&job)) {
        job();
      }
    }
  };

  // Study: a thread pool of course needs its worker threads.
  for (std::size_t started = 0; started != threads; ++started) {
    workers_.emplace_back(worker_loop);
  }
}

// Schedules f(args...) for execution on a worker thread and returns a future
// for its result.
//
// before using the return value, you should check value.valid(): when the
// pool has already been stopped the task is rejected and a default-constructed
// (invalid) future is returned.
template <typename F, typename... Args>
auto ThreadPool::Enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type> {
  using return_type = typename std::result_of<F(Args...)>::type;

  // don't allow enqueueing after stopping the pool. Checked first so that a
  // rejected call neither allocates a task nor consumes the forwarded
  // arguments.
  if (stop_) {
    return std::future<return_type>();
  }

  // packaged_task is move-only; holding it in a shared_ptr keeps the lambda
  // below copyable so it can be stored in a std::function<void()>.
  auto task = std::make_shared<std::packaged_task<return_type()>>(
      std::bind(std::forward<F>(f), std::forward<Args>(args)...));

  std::future<return_type> res = task->get_future();

  // NOTE(review): the result of task_queue_.Enqueue() is not inspected. If
  // the bounded queue can refuse the task (e.g. when full - confirm against
  // BoundedQueue), the task is dropped and the future presumably reports
  // broken_promise on get().
  task_queue_.Enqueue([task]() { (*task)(); });
  return res;
}

// the destructor joins all threads: it raises the stop flag, wakes every
// worker blocked on the task queue and then waits for each of them to finish.
inline ThreadPool::~ThreadPool() {
  const bool already_stopped = stop_.exchange(true);
  if (already_stopped) {
    return;
  }

  task_queue_.BreakAllWait();
  for (auto& worker : workers_) {
    worker.join();
  }
}

} // namespace base
} // namespace cyber
} // namespace apollo

#endif // CYBER_BASE_THREAD_POOL_H_

cyber/base/thread_safe_queue

就是一個 queue 加上一個 mutex。
而它為了同時提供 wait queue 的功能, 所以也加了 condition_variable

cyber/base/unbounded_queue

bounded_queue 的 unbounded 版。
不過要留意兩者的實現完全不一樣: unbounded_queue 為了 thread safe,
底層用的是 linked list, 而不是 dynamic array。
unbounded_queue 的性能理論上會比 bounded_queue 低,
而且它全部都用 compare_exchange_strong, 應該是設計上就不是給大量並發的場景用的