Stanford CS149 (Fall 2023) Parallel Computing

Programming Assignment 4: Chat149 - A Flash Attention Transformer DNN

Github repo

Environment Setup

Ugh. As a non-Stanford student I don't get the course servers, so I had to set up the environment myself, and there's no requirements.txt either. Installing torch took forever because of version conflicts; only after digging through the repo issues did I learn you need version 2.1.2. For the remaining modules I just kept rebuilding and let the compiler errors tell me which dependencies were missing.

Warm-Up: Accessing Tensors (3 Points)

In short: implement the mapping between a 4D tensor (essentially a four-dimensional array) and a flat 1D array. Easy.

(As it turns out, this implementation has a small performance problem that I only tracked down in Part 2.)

inline float fourDimRead(std::vector<float> &tensor, int &x, int &y, int &z,
                         int &b, const int &sizeX, const int &sizeY,
                         const int &sizeZ) {
  return tensor[((x * sizeX + y) * sizeY + z) * sizeZ + b];
}

inline void fourDimWrite(std::vector<float> &tensor, int &x, int &y, int &z,
                         int &b, const int &sizeX, const int &sizeY,
                         const int &sizeZ, float &val) {
  tensor[((x * sizeX + y) * sizeY + z) * sizeZ + b] = val;
}
~/codes/CS149/cs149gpt (master*) » python3 gpt149.py 4Daccess                                                                                          mizukicry@S-Terminal

Compiling code into a PyTorch module...



Tensor Shape: torch.Size([1, 2, 4, 4])

4D Tensor Contents:
tensor([[[[0.0000e+00, 1.0000e-04, 2.0000e-04, 3.0000e-04],
[2.0000e-04, 3.0000e-04, 4.0000e-04, 5.0000e-04],
[4.0000e-04, 5.0000e-04, 6.0000e-04, 7.0000e-04],
[6.0000e-04, 7.0000e-04, 8.0000e-04, 9.0000e-04]],

[[0.0000e+00, 1.0000e-04, 2.0000e-04, 3.0000e-04],
[2.0000e-04, 3.0000e-04, 4.0000e-04, 5.0000e-04],
[4.0000e-04, 5.0000e-04, 6.0000e-04, 7.0000e-04],
[6.0000e-04, 7.0000e-04, 8.0000e-04, 9.0000e-04]]]])

Indexing Value When: x = 0, y = 0, z = 3, b = 0
Expected: 0.0006
Result: 0.0006

Part 1: A Simple (But Not So Efficient) Implementation of Attention (10 Points)

Implement the most basic attention module.

Overall it's just three steps: matrix multiply, softmax, matrix multiply. Write them out one after another.
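For reference, the computation these three steps perform, per batch element and attention head (this version applies no 1/sqrt(d) scaling, matching the code below):

    S = Q * K^T          (N x N)
    P = softmax(S)       (applied row-wise)
    O = P * V            (N x d)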

// -------- YOUR CODE HERE  -------- //

for (int b = 0; b < B; b++) {
  for (int h = 0; h < H; h++) {
    // QK_t = Q * K^t
    for (int i = 0; i < N; i++) {
      for (int j = 0; j < N; j++) {
        float sum = 0.0;
        for (int k = 0; k < d; k++) {
          sum += fourDimRead(Q, b, h, i, k, H, N, d) *
                 fourDimRead(K, b, h, j, k, H, N, d);
        }
        twoDimWrite(QK_t, i, j, N, sum);
      }
    }

    // softmax(QK_t)
    for (int i = 0; i < N; i++) {
      float sum = 0.0;
      for (int j = 0; j < N; j++) {
        sum += std::exp(twoDimRead(QK_t, i, j, N));
      }
      for (int j = 0; j < N; j++) {
        float val = std::exp(twoDimRead(QK_t, i, j, N)) / sum;
        twoDimWrite(QK_t, i, j, N, val);
      }
    }

    // O = QK_t * V
    for (int i = 0; i < N; i++) {
      for (int j = 0; j < d; j++) {
        float sum = 0.0;
        for (int k = 0; k < N; k++) {
          sum += twoDimRead(QK_t, i, k, N) * fourDimRead(V, b, h, k, j, H, N, d);
        }
        fourDimWrite(O, b, h, i, j, H, N, d, sum);
      }
    }
  }
}
~/codes/CS149/cs149gpt (master*) » python3 gpt149.py part1                                                                                             mizukicry@S-Terminal

Compiling code into a PyTorch module...


Running Part 1 Test: Naive Unfused Attention

-----RUNNING REFERENCE IMPLEMENTATION-----

WARNING:2024-04-20 13:56:02 194936:194936 init.cpp:155] function cbapi->getCuptiStatus() failed with error CUPTI_ERROR_NOT_INITIALIZED (15)
WARNING:2024-04-20 13:56:02 194936:194936 init.cpp:156] CUPTI initialization failed - CUDA profiler activities will be missing
INFO:2024-04-20 13:56:02 194936:194936 init.cpp:158] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2024-04-20 13:56:02 194936:194936 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 13:56:02 194936:194936 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 13:56:02 194936:194936 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.24445605278015137

------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.05% 134.000us 0.05% 134.000us 44.667us 5.00 Mb 5.00 Mb 3
REFERENCE - NAIVE ATTENTION 91.28% 224.579ms 99.21% 244.096ms 244.096ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.64% 1.584ms 4.24% 10.432ms 5.216ms 4.50 Mb 0 b 2
aten::clone 0.56% 1.382ms 3.19% 7.857ms 3.929ms 1.00 Mb 0 b 2
model_inference 0.79% 1.947ms 100.00% 246.043ms 246.043ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.48% 1.175ms 3.20% 7.863ms 1.573ms 512.00 Kb 0 b 5
aten::empty_like 0.06% 137.000us 0.07% 168.000us 168.000us 512.00 Kb 0 b 1
aten::empty_strided 0.06% 139.000us 0.06% 139.000us 139.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.89% 2.186ms 3.55% 8.745ms 4.372ms 0 b 0 b 2
aten::fill_ 2.67% 6.559ms 2.67% 6.559ms 3.280ms 0 b 0 b 2
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 246.043ms

REFERENCE - NAIVE ATTENTION statistics
cpu time: 244.096ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

STAGE:2024-04-20 13:56:08 194936:194936 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 13:56:08 194936:194936 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 13:56:08 194936:194936 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.20301079750061035

----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 22.000us 0.01% 22.000us 7.333us 5.00 Mb 5.00 Mb 3
STUDENT - NAIVE ATTENTION 99.40% 201.873ms 99.93% 202.955ms 202.955ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.02% 45.000us 0.17% 342.000us 171.000us 4.50 Mb 0 b 2
aten::clone 0.06% 118.000us 0.32% 654.000us 327.000us 1.00 Mb 0 b 2
model_inference 0.07% 133.000us 100.00% 203.088ms 203.088ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.02% 37.000us 0.21% 428.000us 85.600us 512.00 Kb 0 b 5
aten::empty_like 0.00% 6.000us 0.00% 10.000us 10.000us 512.00 Kb 0 b 1
aten::empty_strided 0.01% 12.000us 0.01% 12.000us 12.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.03% 54.000us 0.14% 279.000us 139.500us 0 b 0 b 2
aten::fill_ 0.11% 225.000us 0.11% 225.000us 112.500us 0 b 0 b 2
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 203.088ms

STUDENT - NAIVE ATTENTION statistics
cpu time: 202.955ms
mem usage: 4718592 bytes

Part 2: Blocked Matrix Multiply and Unfused Softmax (20 Points)

Use the blocked matrix multiplication from the lecture slides to improve cache locality.

My first implementation was mysteriously quite a bit slower than the reference. After some poking around, the culprit turned out to be the 4D tensor accessors from the warm-up: I had "cleverly" optimized them to use fewer multiplications, which actually got in the way of the compiler's own optimizations. Switching back to computing the flat offset directly made it noticeably faster.

// Step #2: Implement Read/Write Accessors for a 4D Tensor
inline float fourDimRead(std::vector<float> &tensor, int &x, int &y, int &z,
                         int &b, const int &sizeX, const int &sizeY,
                         const int &sizeZ) {
  // return tensor[((x * sizeX + y) * sizeY + z) * sizeZ + b];
  return tensor[x * sizeX * sizeY * sizeZ + y * sizeY * sizeZ + z * sizeZ + b];
}

inline void fourDimWrite(std::vector<float> &tensor, int &x, int &y, int &z,
                         int &b, const int &sizeX, const int &sizeY,
                         const int &sizeZ, float &val) {
  // tensor[((x * sizeX + y) * sizeY + z) * sizeZ + b] = val;
  tensor[x * sizeX * sizeY * sizeZ + y * sizeY * sizeZ + z * sizeZ + b] = val;
}

Next, the block size. In theory it's best when each row of a block exactly fills one cache line. On my machine, cat /sys/devices/system/cpu/cpu1/cache/index0/coherency_line_size reports a cache line size of 64 bytes, so I used 16 (floats) as the block size. (Only later did I notice the handout says this as well.)
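A tiny sanity check of that choice (my own sketch; the constant names are made up, and 64 is just the value read from sysfs above):

// With 64-byte cache lines and 4-byte floats, one block row of 16 floats
// fills exactly one cache line.
constexpr int kCacheLineBytes = 64;                          // coherency_line_size
constexpr int kBlockSize = kCacheLineBytes / sizeof(float);  // = 16
static_assert(kBlockSize == 16, "one block row per cache line");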

When benchmarking, I first ran it in the library with my laptop in silent (power-saving) mode and was more than 20 ms behind the reference; switching to balanced mode fixed that. Measuring program performance is an art of its own.

Later I also tuned the loops: for conditions like for (int i = b_i; i < b_i + L && i < N; i++), precompute min(b_i + L, N) once and use that value as the loop bound instead. The effect was quite noticeable (see the sketch below).
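The transformation looks roughly like this (same behavior, but the bound is computed once per block instead of being re-evaluated in every iteration's condition):

// Before: compound condition checked on every iteration
for (int i = b_i; i < b_i + L && i < N; i++) { /* ... */ }

// After: hoist the bound out of the loop (what the code below does)
const int m_i = std::min(b_i + L, N);
for (int i = b_i; i < m_i; i++) { /* ... */ }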

// -------- YOUR CODE HERE  -------- //

constexpr int L = 16;

for (int b = 0; b < B; b++) {
  for (int h = 0; h < H; h++) {
    // QK_t = Q * K^t
    for (int b_i = 0; b_i < N; b_i += L) {
      for (int b_j = 0; b_j < N; b_j += L) {
        for (int b_k = 0; b_k < d; b_k += L) {
          int m_i = std::min(N, b_i + L);
          int m_j = std::min(N, b_j + L);
          int m_k = std::min(d, b_k + L);
          for (int i = b_i; i < m_i; i++) {
            for (int j = b_j; j < m_j; j++) {
              float sum = twoDimRead(QK_t, i, j, N);
              for (int k = b_k; k < m_k; k++) {
                sum += fourDimRead(Q, b, h, i, k, H, N, d) *
                       fourDimRead(K, b, h, j, k, H, N, d);
              }
              twoDimWrite(QK_t, i, j, N, sum);
            }
          }
        }
      }
    }

    // softmax(QK_t)
    for (int i = 0; i < N; i++) {
      float sum = 0.0;
      for (int j = 0; j < N; j++) {
        sum += std::exp(twoDimRead(QK_t, i, j, N));
      }
      for (int j = 0; j < N; j++) {
        float val = std::exp(twoDimRead(QK_t, i, j, N)) / sum;
        twoDimWrite(QK_t, i, j, N, val);
      }
    }

    // O = QK_t * V
    for (int b_i = 0; b_i < N; b_i += L) {
      for (int b_j = 0; b_j < d; b_j += L) {
        for (int b_k = 0; b_k < N; b_k += L) {
          int m_i = std::min(N, b_i + L);
          int m_j = std::min(d, b_j + L);
          int m_k = std::min(N, b_k + L);
          for (int i = b_i; i < m_i; i++) {
            for (int j = b_j; j < m_j; j++) {
              float sum = fourDimRead(O, b, h, i, j, H, N, d);
              for (int k = b_k; k < m_k; k++) {
                sum += twoDimRead(QK_t, i, k, N) *
                       fourDimRead(V, b, h, k, j, H, N, d);
              }
              fourDimWrite(O, b, h, i, j, H, N, d, sum);
            }
          }
        }
      }
    }
  }
}
~/codes/CS149/cs149gpt (master) » python3 gpt149.py part2                                                                                              mizukicry@S-Terminal

Compiling code into a PyTorch module...


Running Part 2 Test: Unfused Attention with Blocked Matmul

-----RUNNING REFERENCE IMPLEMENTATION-----

WARNING:2024-04-20 16:28:08 262231:262231 init.cpp:155] function cbapi->getCuptiStatus() failed with error CUPTI_ERROR_NOT_INITIALIZED (15)
WARNING:2024-04-20 16:28:08 262231:262231 init.cpp:156] CUPTI initialization failed - CUDA profiler activities will be missing
INFO:2024-04-20 16:28:08 262231:262231 init.cpp:158] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2024-04-20 16:28:08 262231:262231 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 16:28:08 262231:262231 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 16:28:08 262231:262231 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.12445354461669922

------------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.02% 29.000us 0.02% 29.000us 9.667us 5.00 Mb 5.00 Mb 3
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX 96.17% 119.717ms 99.93% 124.403ms 124.403ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.02% 30.000us 0.47% 584.000us 292.000us 4.50 Mb 0 b 2
aten::clone 0.04% 44.000us 3.22% 4.010ms 2.005ms 1.00 Mb 0 b 2
model_inference 0.07% 81.000us 100.00% 124.484ms 124.484ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.06% 72.000us 0.91% 1.139ms 227.800us 512.00 Kb 0 b 5
aten::empty_like 0.01% 7.000us 0.01% 13.000us 13.000us 512.00 Kb 0 b 1
aten::empty_strided 0.02% 28.000us 0.02% 28.000us 28.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.02% 27.000us 0.43% 531.000us 265.500us 0 b 0 b 2
aten::fill_ 0.40% 504.000us 0.40% 504.000us 252.000us 0 b 0 b 2
------------------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 124.484ms

REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 124.403ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

STAGE:2024-04-20 16:28:13 262231:262231 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 16:28:13 262231:262231 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 16:28:13 262231:262231 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.10607481002807617

---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 14.000us 0.01% 14.000us 4.667us 5.00 Mb 5.00 Mb 3
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX 99.27% 105.336ms 99.94% 106.045ms 106.045ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.01% 14.000us 0.20% 208.000us 104.000us 4.50 Mb 0 b 2
aten::clone 0.03% 27.000us 0.44% 462.000us 231.000us 1.00 Mb 0 b 2
model_inference 0.06% 61.000us 100.00% 106.106ms 106.106ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.02% 24.000us 0.26% 276.000us 55.200us 512.00 Kb 0 b 5
aten::empty_like 0.00% 3.000us 0.00% 4.000us 4.000us 512.00 Kb 0 b 1
aten::empty_strided 0.01% 10.000us 0.01% 10.000us 10.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.01% 8.000us 0.17% 181.000us 90.500us 0 b 0 b 2
aten::fill_ 0.16% 173.000us 0.16% 173.000us 86.500us 0 b 0 b 2
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 106.106ms

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 106.045ms
mem usage: 4718592 bytes

Part 3: Fused Attention (25 Points)

Parallelize with OpenMP.

Since each row of the matrix can be processed independently, we can use multiple threads for speedup. All it takes is a single #pragma line, which is very convenient (a minimal sketch of the pattern follows).
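A minimal sketch of the pattern (the function name and the scratch vector are mine, just to illustrate; the real code below reuses the temp tensor provided by the starter code instead):

#include <omp.h>
#include <vector>

// Each (b, h, i) row is independent, so collapse(3) lets OpenMP distribute all
// B * H * N rows across threads; every thread needs its own scratch row so the
// intermediate softmax values don't get clobbered.
void fusedRows(int B, int H, int N) {
  std::vector<std::vector<float>> scratch(omp_get_max_threads(),
                                          std::vector<float>(N));
#pragma omp parallel for collapse(3)
  for (int b = 0; b < B; b++)
    for (int h = 0; h < H; h++)
      for (int i = 0; i < N; i++) {
        std::vector<float> &row = scratch[omp_get_thread_num()];
        // compute row = softmax(Q[b,h,i,:] * K^T), then O[b,h,i,:] = row * V
        (void)row;
      }
}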

// -------- YOUR CODE HERE  -------- //
// We give you a template of the first three loops for your convenience
// loop over batch
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++) {
  // loop over heads
  for (int h = 0; h < H; h++) {
    for (int i = 0; i < N; i++) {
      // YRow is moved inside so each OpenMP thread gets a local copy.
      at::Tensor ORowTensor = temp.index({torch::indexing::Slice(
          omp_get_thread_num(), torch::indexing::None)});
      std::vector<float> ORow = formatTensor(ORowTensor);
      // YOUR CODE HERE

      // QK_t = Q * K^t
      for (int j = 0; j < N; j++) {
        float sum = 0.0;
        for (int k = 0; k < d; k++) {
          sum += fourDimRead(Q, b, h, i, k, H, N, d) *
                 fourDimRead(K, b, h, j, k, H, N, d);
        }
        ORow[j] = sum;
      }

      // softmax(QK_t)
      float sum = 0.0;
      for (int j = 0; j < N; j++) {
        ORow[j] = std::exp(ORow[j]);
        sum += ORow[j];
      }
      for (int j = 0; j < N; j++) {
        ORow[j] /= sum;
      }

      // O = QK_t * V
      for (int j = 0; j < d; j++) {
        float sum = 0.0;
        for (int k = 0; k < N; k++) {
          sum += ORow[k] * fourDimRead(V, b, h, k, j, H, N, d);
        }
        fourDimWrite(O, b, h, i, j, H, N, d, sum);
      }
    }
  }
}
~/codes/CS149/cs149gpt (master*) » python3 gpt149.py part3                                         mizukicry@S-Terminal

Compiling code into a PyTorch module...


Running Part 3 Test: Fused Attention

-----RUNNING REFERENCE IMPLEMENTATION-----

WARNING:2024-04-20 22:41:40 8263:8263 init.cpp:155] function cbapi->getCuptiStatus() failed with error CUPTI_ERROR_NOT_INITIALIZED (15)
WARNING:2024-04-20 22:41:40 8263:8263 init.cpp:156] CUPTI initialization failed - CUDA profiler activities will be missing
INFO:2024-04-20 22:41:40 8263:8263 init.cpp:158] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2024-04-20 22:41:40 8263:8263 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 22:41:40 8263:8263 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 22:41:40 8263:8263 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.03698134422302246

------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.10% 38.000us 0.10% 38.000us 12.667us 1.03 Mb 1.03 Mb 3
aten::clone 0.12% 46.000us 1.06% 391.000us 195.500us 1.00 Mb 0 b 2
REFERENCE - FUSED ATTENTION 91.59% 33.900ms 99.79% 36.935ms 36.935ms 544.00 Kb -1.00 Mb 1
aten::zeros 0.11% 39.000us 0.98% 362.000us 181.000us 544.00 Kb 0 b 2
model_inference 0.21% 79.000us 100.00% 37.014ms 37.014ms 512.00 Kb -32.00 Kb 1
aten::flatten 2.66% 986.000us 3.96% 1.465ms 2.839us 512.00 Kb 0 b 516
aten::empty_like 0.04% 15.000us 0.08% 29.000us 29.000us 512.00 Kb 0 b 1
aten::empty_strided 0.04% 13.000us 0.04% 13.000us 13.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.10% 36.000us 0.81% 299.000us 149.500us 0 b 0 b 2
aten::fill_ 0.71% 263.000us 0.71% 263.000us 263.000us 0 b 0 b 1
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 37.014ms

REFERENCE - FUSED ATTENTION statistics
cpu time: 36.935ms
mem usage: 557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

STAGE:2024-04-20 22:41:45 8263:8263 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-20 22:41:45 8263:8263 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-20 22:41:45 8263:8263 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.032499074935913086

----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.12% 38.000us 0.12% 38.000us 9.500us 1.04 Mb 1.04 Mb 4
aten::clone 0.08% 26.000us 1.33% 432.000us 216.000us 1.00 Mb 0 b 2
aten::zeros 0.07% 23.000us 0.45% 148.000us 49.333us 548.00 Kb 0 b 3
STUDENT - FUSED ATTENTION 93.30% 30.352ms 99.80% 32.466ms 32.466ms 544.00 Kb -1.00 Mb 1
model_inference 0.20% 65.000us 100.00% 32.531ms 32.531ms 512.00 Kb -32.00 Kb 1
aten::flatten 1.68% 545.000us 2.61% 848.000us 1.640us 512.00 Kb 0 b 517
aten::empty_like 0.01% 4.000us 0.04% 14.000us 14.000us 512.00 Kb 0 b 1
aten::empty_strided 0.06% 20.000us 0.06% 20.000us 20.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.04% 12.000us 0.30% 97.000us 32.333us 0 b 0 b 3
aten::fill_ 0.26% 85.000us 0.26% 85.000us 85.000us 0 b 0 b 1
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 32.531ms

STUDENT - FUSED ATTENTION statistics
cpu time: 32.466ms
mem usage: 557056 bytes

Part 4: Putting it all Together - Flash Attention (35 Points)

Combine the blocked matrices from Part 2 with the OpenMP parallelization from Part 3.

The handout already gives pseudocode for the algorithm; it's just a matter of translating it into C++.

That said, the pseudocode is a bit confusing; it doesn't keep 0-indexed and 1-indexed notation straight. Also, the tensor accessor parameters are, surprisingly, not all const&; I changed that, since otherwise you can't pass temporaries (roughly as sketched below).
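Roughly what that change looks like (the body is the Part 2 version; only the const qualifiers on the index parameters are new):

inline float fourDimRead(std::vector<float> &tensor, const int &x, const int &y,
                         const int &z, const int &b, const int &sizeX,
                         const int &sizeY, const int &sizeZ) {
  return tensor[x * sizeX * sizeY * sizeZ + y * sizeY * sizeZ + z * sizeZ + b];
}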

One more note: the pseudocode omits the two outer loops (over batch size and number of heads). When writing the code, remember to reinitialize the per-(b, h) buffers on each iteration; the easiest way is to simply define them inside the loop.

Part 4 isn't graded on performance, mainly on correctness, so I didn't do any real optimization; even a straightforward translation of the pseudocode already runs fast.
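For reference, the per-block update that the inner loops implement (following the handout's pseudocode as I read it):

    S_ij  = Q_i * K_j^T
    P_ij  = exp(S_ij)
    l_ij  = rowsum(P_ij)
    l_new = l_i + l_ij
    O_i   = (l_i * O_i + P_ij * V_j) / l_new    (scaling and division applied per row)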

// Format All Tensors into Vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);

// -------- YOUR CODE HERE -------- //
// constexpr int M = 512 /*KB*/ * 1024 / sizeof(float);
// const int Bc = (M + 4 * d - 1) / (4 * d);
// const int Br = std::min(Bc, d);
const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;

for (int b = 0; b < B; b++) {
  for (int h = 0; h < H; h++) {
    std::vector<float> Sij = formatTensor(SijTensor);
    std::vector<float> Pij = formatTensor(PijTensor);
    std::vector<float> Kj = formatTensor(KjTensor);
    std::vector<float> Vj = formatTensor(VjTensor);
    std::vector<float> Qi = formatTensor(QiTensor);
    std::vector<float> Oi = formatTensor(OiTensor);
    std::vector<float> l = formatTensor(LTensor);
    std::vector<float> PV = formatTensor(PVTensor);
    std::vector<float> li = formatTensor(LiTensor);
    std::vector<float> lij = formatTensor(LijTensor);
    std::vector<float> lnew = formatTensor(LnewTensor);

    for (int j = 0; j < Tc; j++) {
      // Load K_j, V_j
      const int mx_Bc = std::min(Bc, N - j * Bc);
      for (int x = 0; x < mx_Bc; x++) {
        for (int y = 0; y < d; y++) {
          twoDimWrite(Kj, x, y, d, fourDimRead(K, b, h, j * Bc + x, y, H, N, d));
          twoDimWrite(Vj, x, y, d, fourDimRead(V, b, h, j * Bc + x, y, H, N, d));
        }
      }

      for (int i = 0; i < Tr; i++) {
        // Load Q_i, O_i, l_i
        const int mx_Br = std::min(Br, N - i * Br);
        for (int x = 0; x < mx_Br; x++) {
          for (int y = 0; y < d; y++) {
            twoDimWrite(Qi, x, y, d, fourDimRead(Q, b, h, i * Br + x, y, H, N, d));
            twoDimWrite(Oi, x, y, d, fourDimRead(O, b, h, i * Br + x, y, H, N, d));
            li[x] = l[i * Br + x];
          }
        }

        // S_ij = Q_i * K_j^t
        for (int x = 0; x < mx_Br; x++) {
          for (int y = 0; y < mx_Bc; y++) {
            float sum = 0.0;
            for (int z = 0; z < d; z++) {
              sum += twoDimRead(Qi, x, z, d) * twoDimRead(Kj, y, z, d);
            }
            twoDimWrite(Sij, x, y, Bc, sum);
          }
        }

        // P_ij = exp(S_ij)
        for (int x = 0; x < mx_Br; x++) {
          for (int y = 0; y < mx_Bc; y++) {
            twoDimWrite(Pij, x, y, Bc, std::exp(twoDimRead(Sij, x, y, Bc)));
          }
        }

        // l_ij = rowsum(P_ij)
        for (int x = 0; x < mx_Br; x++) {
          float sum = 0.0;
          for (int y = 0; y < mx_Bc; y++) {
            sum += twoDimRead(Pij, x, y, Bc);
          }
          lij[x] = sum;
        }

        // l_new = l_i + l_ij
        for (int x = 0; x < mx_Br; x++) {
          lnew[x] = li[x] + lij[x];
        }

        // O_i = (l_i * O_i + P_ij * V_j) / l_new
        for (int x = 0; x < mx_Br; x++) {
          for (int y = 0; y < d; y++) {
            float sum = 0.0;
            for (int z = 0; z < mx_Bc; z++) {
              sum += twoDimRead(Pij, x, z, Bc) * twoDimRead(Vj, z, y, d);
            }
            twoDimWrite(Oi, x, y, d,
                        (li[x] * twoDimRead(Oi, x, y, d) + sum) / lnew[x]);
          }
        }

        // Write O_i back to O, l_new back to l
        for (int x = 0; x < mx_Br; x++) {
          for (int y = 0; y < d; y++) {
            fourDimWrite(O, b, h, i * Br + x, y, H, N, d, twoDimRead(Oi, x, y, d));
          }
          l[i * Br + x] = lnew[x];
        }
      }
    }
  }
}
~/codes/CS149/cs149gpt (master*) » python3 gpt149.py part4                                                              mizukicry@S-Terminal

Compiling code into a PyTorch module...


Running Part 4 Test: Flash Attention

-----RUNNING REFERENCE IMPLEMENTATION-----

WARNING:2024-04-21 12:05:12 103099:103099 init.cpp:155] function cbapi->getCuptiStatus() failed with error CUPTI_ERROR_NOT_INITIALIZED (15)
WARNING:2024-04-21 12:05:12 103099:103099 init.cpp:156] CUPTI initialization failed - CUDA profiler activities will be missing
INFO:2024-04-21 12:05:12 103099:103099 init.cpp:158] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2024-04-21 12:05:13 103099:103099 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-21 12:05:13 103099:103099 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-21 12:05:13 103099:103099 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.5386707782745361

------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::zeros 0.02% 93.000us 0.51% 2.733ms 195.214us 9.16 Mb -31.00 Kb 14
aten::empty 0.03% 136.000us 0.03% 136.000us 9.714us 9.16 Mb 9.16 Mb 14
model_inference 0.06% 317.000us 100.00% 538.752ms 538.752ms 512.00 Kb -679.00 Kb 1
REFERENCE - FLASH ATTENTION 94.26% 507.821ms 99.85% 537.959ms 537.959ms 512.00 Kb -8.00 Mb 1
aten::zero_ 0.64% 3.469ms 5.63% 30.323ms 81.954us 32.00 Kb 32.00 Kb 370
aten::fill_ 5.00% 26.916ms 5.00% 26.916ms 202.376us 0 b 0 b 133
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 538.752ms

REFERENCE - FLASH ATTENTION statistics
cpu time: 537.959ms
mem usage: 524288 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

STAGE:2024-04-21 12:05:18 103099:103099 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-04-21 12:05:18 103099:103099 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-04-21 12:05:18 103099:103099 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.16223454475402832

----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 19.000us 0.01% 19.000us 1.462us 1.66 Mb 1.66 Mb 13
aten::zeros 0.02% 29.000us 0.06% 103.000us 8.583us 1.16 Mb -31.00 Kb 12
aten::clone 0.04% 60.000us 0.26% 416.000us 208.000us 1.00 Mb 0 b 2
model_inference 0.06% 94.000us 100.00% 162.258ms 162.258ms 512.00 Kb -684.00 Kb 1
STUDENT - FLASH ATTENTION 99.48% 161.410ms 99.89% 162.082ms 162.082ms 512.00 Kb -1.00 Mb 1
aten::flatten 0.09% 147.000us 0.26% 423.000us 8.812us 512.00 Kb 0 b 48
aten::empty_like 0.00% 8.000us 0.01% 11.000us 11.000us 512.00 Kb 0 b 1
aten::empty_strided 0.01% 10.000us 0.01% 10.000us 10.000us 512.00 Kb 512.00 Kb 1
aten::zero_ 0.01% 12.000us 0.04% 58.000us 4.833us 37.00 Kb 37.00 Kb 12
aten::fill_ 0.03% 46.000us 0.03% 46.000us 15.333us 0 b 0 b 3
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 162.258ms

STUDENT - FLASH ATTENTION statistics
cpu time: 162.082ms
mem usage: 524288 bytes
Extra Credit: Optimize Further (12 Total Points - 3 Points Per Part)

That is, keep optimizing with ISPC. I'm not doing this part and am skipping it.