KV Cache是目前LLM Serving架構中,用來加速推論(inference)的一個重要機制,藉由減少不避要的浪費,進而達成加速的目的 當今天使用者輸入了一段提示詞(prompt),而我們的自迴歸解碼器(autoregressive decoder)在生成回應時,可以概略的將此階段分成2部分,一部分是所謂的預填充階段(prefill phase),另一部分是解碼階段(decode phase) 以下圖為例,我們現在的輸入有4個符元(token),此時可以透過矩陣運算,產生$QK^TV$,之後透過transformer所產生的輸出,來預測下一個符元,這個階段。由於是自迴歸的關係,我們將之前的4個符元,再加上這一個新生成的符元,又再做一次$QK^TV$運算,即可得知下一個成生的符元會是什麼。由於causal masking的關係(每一個符元只跟它之前出現的符元有關,因為自迴歸生成的本質是一種文字接龍),當我們要計算這5個符元的$QK^TV$時,其實前面4個符元的$QK^TV$,已在上一輪就計算過了,我們可以將上一輪計算後的結果,存在KV Cache中(如下圖粉紅色的部分),此時我們只要計算第5個符元跟之前4個符元間的attention即可,一般稱此階段為decode phase。 ![image](https://hackmd.io/_uploads/SJPhor0zxe.png) 那為何會導入Page Attention這樣的概念呢?因為我們並不確定LLM在產生輸出時,到底最後輸出的符元會是幾個,比如以下圖為例,當使用者輸入”Four score and seven years ago our"這個句子時,此時LLM己輸出了"father brought",接著直到句子結束前,預計還會有"forth <eos>"等字詞。由於這樣的不確定性,比較簡單的做法是分配一塊較大的記憶體以利於KV Cache的存放及使用,但這樣的做法對於輸出較短的句子,又會增加所謂的internal fragmenttation,可以說是有一好沒兩好,而且較大塊的記憶體,在某些時候也會引起external fragmentation的問題(或許是alignment之類的問題引起?) ![image](https://hackmd.io/_uploads/Hk7zpYCfel.png) 所以原作者就想到可以利用作業系統中虛擬記憶體(virtual memory)的概念,一段連續的logical memory,以固定大小的block做分割。以下圖為例,若每個block可放置4個字詞,則"Four score and seven years ago our fathers brought forth",可利用3個非連續的block來儲存相關的資訊 ![image](https://hackmd.io/_uploads/ByIGmqCzgx.png) 下面是另一範例,分成step 1, 2, 3 在step 1,我們會將"Four score and seven years ago our"等7個字詞存在logical kv blocks中的block 0及1,之後透過block table去映射至physical kv block,其中的#filled欄位,表示此block目前有幾個字詞儲存在其中,每當我們新增新的字詞時,我們便需更新相對應的數量。在step 2我們會產生fathers這個字詞,並在相對應的physical block 1新增此字詞,並更新block table上#filled欄位。在step 3我們會產生brought這個字詞,但因為之前的block己滿,需分配新的logical及physical block,並更新block table建立新的logical to physical關係 ![image](https://hackmd.io/_uploads/H1aQm50fle.png) 下圖是同時有2個不同的request時,physical kv blocks呈現出來的結果 ![image](https://hackmd.io/_uploads/HyTxYCAMgx.png) 下面這個例子,是2個非常相似的request,它們在前7個字詞都是一樣的,在產生第8個不同的字詞前,我們只需要2個physical block即可儲存相關的kv cache資訊,當產生分歧的第8個字詞時,此時再分配另一個physical kv block即可 ![image](https://hackmd.io/_uploads/BJg9YARzxg.png) --- 在vLLM中,page attention己經被整合進去了,可參考其[原始碼][3]及[說明文件][4] 我們先對一些參數的數值做一些假設,以利我們後續的說明, ```cuda // cuda grid & block setting num_heads = gridDim.x; max_num_partitions = gridDim.z; head_idx = blockIdx.x; seq_idx = blockIdx.y; partition_idx = blockIdx.z; thread_idx = threadIdx.x; // head的大小 HEAD_SIZES = 128 // 每個block會存幾個符元(token) BLOCK_SIZES = 16 // 每個warp中有幾個執行緒,一般在cuda為32,在ROCm中,不同平台有32、64等不同設定 WARP_SIZE = 32 // 每個thread block中的執行緒個數,預設值為128,即4個warps NUM_THREADS = 128 // warp數量 NUM_WARPS = NUM_THREADS / WARP_SIZE = 128 / 32 = 4 // 每個warp中,負責完成一個block所需qkv運算的執行緒個數。比如一個warp有32個執行緒 // ,若一個block有16個符元,那每2個執行緒可處理一個符元所需的計算 THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1) = 32 / 16 = 2 // 若一個thread block有128個thread,且每個thread group size為2,則可算出有64個thread group NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE = 128 / 2 = 64 // 每個thread group負責的token數量,最少為1 NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE) = DIVIDE_ROUND_UP(16 / 32) = 1 // vector size指的是每個thread,在每次讀取資料時的最小單位,它是以elements的個數為單位 // 比如今天的資料型別若為half,16 bytes的資料除以2,可得到一共有8個elements // 若一個thread group有2個thread,則每個thread負責4個elements,vector size即為4 // page attentition的設計中,會希望每次處理一小塊記憶體(即16 bytes), // 而這塊記憶體的elements數,即為程式碼中常見的x參數 VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1) = MAX(16 / (2 * 2), 1) = 4 // 每個thread要負責讀取或計算一個head中的element個數 NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE = 128 / 2 = 64 // 同上,只是換成以vector為單位 NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE = 64 / 4 = 16 ``` page attention的實作除了利用block table將logical memory對應至physical memory外,最重要的是它利用了CUDA的平行運算能力,來加速$QK^TV$的運算 因為自迴歸的關係,我們僅考慮解碼階段(decoding phase),所以q實際上並不是一整個句子,而是單一個一個符元(token),僅計算單一個符元跟之前kv cache的相對關係。由初始cuda grid & block的設定,我們可以知道每個thread block負責一個head的計算,因為在單一個子空間(subspace)之間的點積(dot product),並不會跟其它子空間牽扯上關係,所以很適合變成平行化的運算 ![image](https://hackmd.io/_uploads/rJAh0HSXll.png) 下面的程式碼是試著載入q中某個sequence(其實是某個單一token)特定的head data到一塊共享記憶體q_vecs當中 ```cuda! const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); } __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a // memory wall right before we use q_vecs ``` 以下會將相關的參數列出,會比較好具象化我們要的結果 > HEAD_SIZE = 128 > THREAD_GROUP_SIZE = 2 > NUM_VEC_PER_THREAD = 16 > q_vecs[2][16] 因為q_vecs是一塊共享記憶體(share memory),在同一個thread block裡頭的thread,是可以共用這一塊記憶體,當我們將資料讀取到這塊記憶體時,並不需要讓每個thread group重複的去從全域記憶體搬同樣的資料,只要讓前面32個thread各自去讀取32個vec進來即可,~~而後面96個thread,因為loop condition test不會通過(`i < NUM_VECS_PER_THREAD;`),所以也沒有它們的事~~,之後我們就可以利用相對應的k_vecs跟q_vecs做點積得出我們所要的結果 ![image](https://hackmd.io/_uploads/ryTTq-E7ll.png) 下面是k_vecs載入資料的過程,有一點類似於q_vecs ```cuda! for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { k_vecs[j] = *reinterpret_cast<const K_vec*>( k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { // Vector conversion from Quant_vec to K_vec. Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vec_quant, *k_scale); } } ``` 同樣的,我們先把一些參數列出來,以利後面的具象化 > NUM_TOKENS_PER_THREAD_GROUP = 1 > WARP_SIZE = 32 > BLOCK_SIZE = 16 > NUM_VECS_PER_THREAD = 16 > THREAD_GROUP_SIZE = 2 > VEC_SIZE = 4 > block_index = 0 因為我們要從k_cache將資料讀進k_vecs裡,所以我們先從原始碼中paged_attention_kernel函式看起,我們可以發現k_cache的layout是長這樣子的[num_blocks, num_kv_heads, head_size/x, block_size, x],我們把圖畫出來會比較清楚一些 ![image](https://hackmd.io/_uploads/r1m90kr7gg.png) 可以看到page attention在這邊做了一些改動,不像一般的layout是採用[num_seq, num_heads, head_size],這裡改成以block為單位,而原本的單純的head資料,現在會由不同的partial head組成,其partial head的個數會是head size/x個(x指的是每個thread group在之後的loop中,每個iteration所要讀取的element個數)。那為什麼要做這樣的改動呢?每個thread group要負責的vec現在會變成相鄰的狀態,如vec 0 (token 0), vec 1 (token 0), vec 0 (token 1), vec 1 (token 1), ...vec 0 (token 15), vec 1 (token 15),那這樣做有什麼好處呢?在cuda中記憶體的存取有所謂的coalesced memory access的問題,當我們有多個執行緒,若它們發出的memory access request是連續的位址,那可以將多個request合併成一個,藉以減少浪費,舉個例子來說,今天我們中午休息要發便當,與其讓每個人排隊,每次都詢問說你要排骨飯、雞腿飯還是控肉飯,倒不如一開始就說現在要發的是排骨飯,有訂的人上前領取,這樣是不是減少了一些不必要動作 利用這一段程式碼,每thread group會載入它們所需的部分資料 ![image](https://hackmd.io/_uploads/rJNXd-r7ll.png) ![image](https://hackmd.io/_uploads/SkB4tZS7xg.png) 如果你把官網的示意圖拿來做比較,會發現概念上是一樣的 ![image](https://hackmd.io/_uploads/HJF1-GrXxx.png) 我們己經知道載入q_vecs跟k_vecs的過程,之後在計算QK的點積(dot product)時,還需要一個外部的迴圈,讓k_vecs能載入不同的token,藉此計算q跟不同k之間的點積,得到相對應的logits ```cuda! q_vecs = ... for ... { k_ptr = ... for ... { k_vecs[i] = ... } ... float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); } ``` 在有了logits(=$softmax\left(\dfrac{QK^T}{\sqrt(d)}\right)$)之後,我們之要把這個權重套用到v,就可以得出我們最後的output。我們同樣把現在的參數設定列出來,以方便具象化 > V_VEC_SIZE = 8 > NUM_V_VECS_PER_ROW = 2 > NUM_ROWS_PER_ITER = 16 > NUM_ROWS_PER_THREAD = 8 > NUM_WARPS = 4 > accs[8] 由於在處理v值時,沒有thread group的概念,所以v_vec的大小,只要把16 bytes除以現在所用的資料型別(=2 bytes),就可得出一個v_vec的大小為8個elements。因為在此例中,每個row的大小為block size的大小(16),所以可以得出每個row有2個v_vec,而一個warp的32個執行緒,一次最多能處理16個row的資料,由於head size為128,我們可以得知若要處理全部的資料需要8次iteration,也就是每個thread需處理不同row,共8個v_vec ![image](https://hackmd.io/_uploads/B1b3UZumeg.png) ![image](https://hackmd.io/_uploads/rJ9TLWOQge.png) 從下面的程式中可以看出有一個取巧的地方在於,每次我們在算v的權重和時,是以多個block下去迭代的,以我們的例子來說,每個迴圈中的每個round,會去計算4個block中各個v的權重和,並累加至accs陣列中,所以我們跑完整個迴圈時,其實是把整體v的資訊分成4個部分 ```cuda= scalar_t zero_value; zero(zero_value); for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). // For blocksparse attention: skip computation on blocks that are not // attended if constexpr (IS_BLOCK_SPARSE) { int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { continue; } } const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); } else { V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the // context, we should explicitly zero out the values since they may // contain NaNs. See // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); } } } ``` 因為偶數的執行緒有左半邊加權和,奇數的有右半邊的加權和,我們可以透過CUDA warp shuffle function,來讓某些特定的執行緒交換資訊 ```cuda= scalar_t zero_value; zero(zero_value); for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). // For blocksparse attention: skip computation on blocks that are not // attended if constexpr (IS_BLOCK_SPARSE) { int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { continue; } } const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); } else { V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the // context, we should explicitly zero out the values since they may // contain NaNs. See // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); } } } ``` ![image](https://hackmd.io/_uploads/r1fTGWdXgg.png) ![image](https://hackmd.io/_uploads/H1NmBZ_mxe.png) 透過__shfl_xor_sync func交換資訊後,每個執行緒就可得到相對於各自的整個row的資訊。所以理論上,每個warp中,每個執行緒都掌握著NUM_ROWS_PER_THREAD個資訊 ![image](https://hackmd.io/_uploads/S13f8rOXge.png) 之後我們必須統整一下各個warp中的資訊,透過下面的程式碼即可達成。我們每次將迭代中的warp,分成上半部跟下半部,下半部的warp將資訊跟上半部相對應的warp融合,在經個幾個round後,warp 0會彙整完所有warp的資訊,此時這個資訊代表的是全部token block在做完$softmax\left(\frac{QK^T}{\sqrt(d)}\right)V$的權重和,最後將此結果寫出即可 ```cuda= // Perform reduction across warps. float* out_smem = reinterpret_cast<float*>(shared_mem); #pragma unroll for (int i = NUM_WARPS; i > 1; i /= 2) { int mid = i / 2; // Upper warps write to shared memory. if (warp_idx >= mid && warp_idx < i) { float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { dst[row_idx] = accs[i]; } } } __syncthreads(); // Lower warps update the output. if (warp_idx < mid) { const float* src = &out_smem[warp_idx * HEAD_SIZE]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { accs[i] += src[row_idx]; } } } __syncthreads(); } ``` ![image](https://hackmd.io/_uploads/ByJ3xfd7gg.png) ![image](https://hackmd.io/_uploads/S1x6xzu7gg.png) Reference: 1. [Understanding KV Cache and Paged Attention in LLMs: A Deep Dive into Efficient Inference][1] 2. [Mastering LLM Techniques: Inference Optimization][2] 3. [Efficient Memory Management for Large Language Model Serving with PagedAttention][3] 4. [https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cuh][4] 5. [vLLM Paged Attention][5] 6. [Using CUDA Warp-Level Primitives][6] 7. [CUDA C++ Programming Guide - 10.22. Warp Shuffle Functions][7] [1]: https://medium.com/my-musings-with-llms/understanding-kv-cache-and-paged-attention-in-llms-a-deep-dive-into-efficient-inference-62fa372432ce [2]: https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/ [3]: https://arxiv.org/abs/2309.06180 [4]: https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cuh [5]: docs.vllm.ai/en/latest/design/kernel/paged_attention.html [6]: https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/ [7]: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions