# Team 16 Final Project Report

110652019 林楷傑
313551030 吳秉澍
313551105 林睿騰
313554025 温柏萱

* Github Link: https://github.com/KJLdefeated/EAI_Final/tree/master
* Model Page:
    * Main Model: https://huggingface.co/c1uc/Llama-3.2-3B-Instruct-lora-4bit-g128/tree/main
    * Draft Model for Speculative Decoding: https://huggingface.co/BensonW/EAI-Final-draft-model-gptq/tree/main

## Methodology

### Model Analysis

![image](https://hackmd.io/_uploads/rkK4Zpr-le.png)

To speed up the inference of a language model, we first examine the structure and number of parameters of `meta-llama/Llama-3.2-3B-Instruct`. We observe that, within each layer, the parameter counts satisfy the following relation: `mlp.gate_proj` $=$ `mlp.up_proj` $=$ `mlp.down_proj` $>$ `self_attn.q_proj` $=$ `self_attn.o_proj` $>$ `self_attn.v_proj` $=$ `self_attn.k_proj`.

### Low-Rank Adaptation (LoRA)

[LoRA](https://arxiv.org/abs/2106.09685) is a parameter-efficient fine-tuning technique proposed in 2021 that can fine-tune large models by training fewer than $1\%$ of their original parameters. LoRA freezes the original weights and injects a pair of small, trainable low-rank matrices into the linear layers of the model. That is, a weight matrix $W \in \mathbb{R}^{d \times k}$ is paired with two low-rank matrices $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, and the tuned weight matrix is $W' = W + BA$. Since $r \ll \min(d, k)$, the number of additional parameters, $r(d+k)$, is far smaller than the $dk$ parameters of $W$.

In our work, we performed LoRA fine-tuning on [Salesforce/wikitext](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-2-raw-v1) (wikitext-2-raw-v1) for approximately 15 minutes on four RTX 3090 GPUs and selected the checkpoint with the lowest perplexity as our final model. We found it crucial to use a low learning rate and to monitor perplexity and model outputs frequently; neglecting these steps can produce a fine-tuned model that generates almost no response.
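The parameter relation from the model analysis and the LoRA overhead $r(d+k)$ can be checked with a few lines of arithmetic. The sketch below assumes the projection shapes of `Llama-3.2-3B-Instruct` from its public `config.json` (hidden size 3072, intermediate size 8192, 24 query heads and 8 key/value heads of dimension 128). It also uses a hypothetical rank $r = 16$, which is not necessarily the rank used in `train_lora.py`.

```python
# Projection shapes (d, k) for ONE decoder layer of Llama-3.2-3B-Instruct.
# ASSUMED from the public config.json: hidden_size=3072, intermediate_size=8192,
# 24 query heads, 8 key/value heads, head_dim=128.
HIDDEN, INTER, HEAD_DIM, N_Q, N_KV = 3072, 8192, 128, 24, 8

PROJ_SHAPES = {
    "self_attn.q_proj": (HIDDEN, N_Q * HEAD_DIM),
    "self_attn.k_proj": (HIDDEN, N_KV * HEAD_DIM),
    "self_attn.v_proj": (HIDDEN, N_KV * HEAD_DIM),
    "self_attn.o_proj": (N_Q * HEAD_DIM, HIDDEN),
    "mlp.gate_proj": (HIDDEN, INTER),
    "mlp.up_proj": (HIDDEN, INTER),
    "mlp.down_proj": (INTER, HIDDEN),
}


def full_params(d: int, k: int) -> int:
    """Entries in a frozen, bias-free weight matrix W of shape d x k."""
    return d * k


def lora_params(d: int, k: int, r: int) -> int:
    """Trainable entries in B (d x r) plus A (r x k), i.e. r * (d + k)."""
    return r * (d + k)


def lora_fraction(r: int) -> float:
    """Fraction of one layer's projection weights that LoRA trains at rank r."""
    base = sum(full_params(d, k) for d, k in PROJ_SHAPES.values())
    extra = sum(lora_params(d, k, r) for d, k in PROJ_SHAPES.values())
    return extra / base


if __name__ == "__main__":
    r = 16  # HYPOTHETICAL rank, for illustration only
    for name, (d, k) in PROJ_SHAPES.items():
        print(f"{name:18s} W: {full_params(d, k):>11,d}  LoRA: {lora_params(d, k, r):>9,d}")
    print(f"trainable fraction at r={r}: {lora_fraction(r):.2%}")
```

At $r = 16$ the adapters add about 0.86% of each layer's projection weights, which is consistent with the "fewer than 1%" figure above.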
For detailed training configurations, please refer to [`train_lora.py`](https://github.com/KJLdefeated/EAI_Final/blob/master/train_lora.py) in our GitHub repository.

### Generative Pre-trained Transformer Quantization (GPTQ)

[GPTQ](https://arxiv.org/abs/2210.17323) was proposed in 2022 as a one-shot post-training quantization method for large language models. It is designed to keep accuracy close to that of the full-precision model at 4-bit or even lower bit-widths, which increases inference throughput and makes it possible to run large language models on a single GPU. GPTQ treats the layer-wise quantization error as an optimization problem: given a weight matrix $W \in \mathbb{R}^{d \times k}$ and calibration activations $X \in \mathbb{R}^{n \times d}$, GPTQ solves

$$
\min_{\hat{W} \in \mathcal{Q}} \frac{1}{2} \left\| X(\hat{W} - W) \right\|_F^2
$$

where $\hat{W}$ is the quantized weight matrix and $\mathcal{Q}$ is the set of matrices whose entries are representable at the target bit-width. GPTQ approximately solves this problem by quantizing the weights one input dimension at a time and updating the not-yet-quantized weights to compensate for the error, using the inverse of the Hessian $H = X^\top X$.

In our work, we first merged the LoRA (PEFT) weights into the base model and then quantized the model to 4-bit precision with a group size of 128, striking a balance between compression efficiency and model performance.

### Activation-aware Weight Quantization (AWQ)

[AWQ](https://dl.acm.org/doi/abs/10.1145/3714983.3714987) targets weight-only INT3/INT4 post-training quantization of LLMs. It builds on the observation that weights do not contribute equally to model quality: protecting only about $1\%$ of salient weight channels, selected by the magnitude of their input activations rather than of the weights themselves, greatly reduces quantization error, because channels with large activations have more influence on the output. Keeping those channels in 16-bit floating point would require hardware-unfriendly mixed-precision kernels, so AWQ instead applies a per-input-channel scale $s = s_X^{\alpha}$, where $s_X$ is the average activation magnitude of each channel and $\alpha \in [0, 1]$ is chosen by grid search. The weights are multiplied by $s$ and the activations divided by $s$, which leaves the layer output unchanged before quantization while scaling up the salient channels relative to the rest; all weights are then group-quantized to 4 bits or fewer.
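Both our GPTQ and AWQ checkpoints store weights as 4-bit integers with a group size of 128. The pure-Python sketch below is illustrative only. It implements plain round-to-nearest (RTN) group quantization, the baseline that GPTQ improves on, evaluates GPTQ's layer-wise objective on random toy data, and checks the identity AWQ relies on: scaling input channel $t$ of $W$ by $s_t$ and the matching activation column by $1/s_t$ leaves $XW$ unchanged. The asymmetric zero-point scheme and the toy sizes are assumptions, not the exact storage format of any library.

```python
import random
from typing import List, Tuple

GROUP_SIZE = 128  # group size of our GPTQ/AWQ checkpoints
BITS = 4          # 4-bit integers -> levels 0..15


def quantize_group(w: List[float], bits: int = BITS) -> Tuple[List[int], float, int]:
    """Asymmetric round-to-nearest (RTN) quantization of one weight group."""
    qmax = 2 ** bits - 1
    lo, hi = min(w), max(w)
    scale = (hi - lo) / qmax or 1e-8  # guard against constant groups
    zero = round(-lo / scale)
    q = [min(max(round(x / scale) + zero, 0), qmax) for x in w]
    return q, scale, zero


def fake_quantize(col: List[float], group_size: int = GROUP_SIZE) -> List[float]:
    """Quantize one weight column group by group and return the dequantized values."""
    out: List[float] = []
    for start in range(0, len(col), group_size):
        q, scale, zero = quantize_group(col[start:start + group_size])
        out.extend((v - zero) * scale for v in q)
    return out


def matmul(X: List[List[float]], W_cols: List[List[float]]) -> List[List[float]]:
    """X @ W, where X is n x d and W is stored as k columns of length d."""
    return [[sum(a * b for a, b in zip(x, w)) for w in W_cols] for x in X]


def gptq_objective(X: List[List[float]], W_cols: List[List[float]],
                   Wq_cols: List[List[float]]) -> float:
    """Layer-wise error 0.5 * ||X (W_hat - W)||_F^2 that GPTQ minimizes."""
    diff_cols = [[a - b for a, b in zip(wq, w)] for w, wq in zip(W_cols, Wq_cols)]
    return 0.5 * sum(v * v for row in matmul(X, diff_cols) for v in row)


random.seed(0)
d, k, n = 256, 4, 8  # toy sizes: two groups of 128 along the input dimension
W = [[random.gauss(0, 0.02) for _ in range(d)] for _ in range(k)]
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
W_rtn = [fake_quantize(col) for col in W]
print(f"RTN layer error: {gptq_objective(X, W, W_rtn):.3e}")

# AWQ-style rescaling: multiply input channel t of W by s[t] and divide column t
# of X by s[t]; before quantization the layer output X @ W is unchanged.
s = [random.uniform(0.5, 2.0) for _ in range(d)]
W_scaled = [[w_t * s_t for w_t, s_t in zip(col, s)] for col in W]
X_scaled = [[x_t / s_t for x_t, s_t in zip(row, s)] for row in X]
```

GPTQ replaces the independent per-weight rounding in `fake_quantize` with the error-compensating updates described above, so for the same bit-width and group size it targets a lower value of `gptq_objective`.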
Because the salient channels are protected, the quantization error stays low and the quantized model remains a high-fidelity approximation of the original. As with GPTQ, we first merged the LoRA (PEFT) weights and then quantized the model to $4$-bit precision with a group size of $128$, striking a balance between compression efficiency and model performance.

### Speculative Decoding

Speculative decoding is an inference speedup strategy that uses two models. A smaller, faster "draft" model proposes several candidate tokens ahead of time, and the larger, more accurate "target" model then verifies all of them in a single forward pass. Draft tokens are accepted up to the first one the target model rejects; at that position, the target model substitutes its own token. Because one forward pass of the target model can produce several tokens, generation time drops, and with greedy verification (or the rejection-sampling rule when sampling) the output follows the same distribution as the target model alone.

In our work, we use `Llama-3.2-1B-Instruct` as the base of our draft model and apply the same LoRA fine-tuning and quantization techniques used on the $3$B model. The resulting model serves as our final draft model and increases throughput during inference with vLLM.

### Our Method

![Blank diagram-2](https://hackmd.io/_uploads/rJGvjcpfel.png)

To combine the techniques above, we begin by improving the model's accuracy on the wikitext-2-raw-v1 dataset with LoRA fine-tuning. This step reduces perplexity and prepares the model for the subsequent steps: merging the LoRA weights, quantizing the merged model with GPTQ or AWQ, and optionally pairing it with the quantized draft model for speculative decoding. To realize this pipeline, we adopt [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang/tree/180ff5eecc2da2231eb3ef29f70aa8d62fd8e168) as inference frameworks.
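The draft-then-verify loop described under Speculative Decoding can be sketched in a few lines of plain Python. This is a toy under stated assumptions, not vLLM's or SGLang's implementation. The two "models" are arbitrary greedy next-token functions, the target checks each draft position with a separate call instead of one batched forward pass, and only greedy verification is shown. The `expected_speedup` helper evaluates the estimate from Leviathan et al. (2023). That estimate assumes a constant per-token acceptance rate `alpha` and a draft-to-target cost ratio `c`, and it shows why a slow draft model can make speculative decoding a net loss.

```python
from typing import Callable, List

# A "model" here is just a greedy next-token function of the current prefix.
NextToken = Callable[[List[int]], int]


def speculative_step(prefix: List[int], draft: NextToken, target: NextToken, k: int) -> List[int]:
    """One round of greedy speculative decoding; returns the tokens to append."""
    # 1) The cheap draft model proposes k tokens autoregressively.
    proposal: List[int] = []
    for _ in range(k):
        proposal.append(draft(prefix + proposal))
    # 2) The target model checks each proposed position. A real engine scores all
    #    k positions in ONE batched forward pass; this toy calls target() per position.
    accepted: List[int] = []
    for tok in proposal:
        expected = target(prefix + accepted)
        if tok != expected:
            accepted.append(expected)  # first mismatch: emit the target's token and stop
            return accepted
        accepted.append(tok)
    # 3) Every draft token matched, so the same target pass yields one bonus token.
    accepted.append(target(prefix + accepted))
    return accepted


def speculative_generate(prompt: List[int], draft: NextToken, target: NextToken,
                         k: int, max_new_tokens: int) -> List[int]:
    """Repeat speculative steps until max_new_tokens tokens have been produced."""
    out = list(prompt)
    while len(out) - len(prompt) < max_new_tokens:
        out += speculative_step(out, draft, target, k)
    return out[len(prompt):len(prompt) + max_new_tokens]


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Leviathan et al. (2023): (1 - alpha^(k+1)) / ((1 - alpha) * (k*c + 1)),
    for acceptance rate 0 <= alpha < 1 and draft/target cost ratio c."""
    expected_tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    return expected_tokens / (k * c + 1)


if __name__ == "__main__":
    target = lambda p: (3 * len(p) + p[-1]) % 10  # toy stand-in for the 3B model
    draft = lambda p: target(p) if len(p) % 4 else (target(p) + 1) % 10  # usually agrees
    print(speculative_generate([1, 2], draft, target, k=6, max_new_tokens=12))
    # Hypothetical acceptance rate 0.8 with k=6: fast draft vs. equally slow draft.
    print(expected_speedup(0.8, 6, 0.05), expected_speedup(0.8, 6, 1.0))
```

Whatever the draft proposes, the greedy output is identical to running the target model alone; only the number of target passes changes. With the hypothetical acceptance rate of 0.8 and $k = 6$ (the `num_speculative_tokens` value in our vLLM configs), the estimate is about 3.0x when the draft costs 5% of a target pass, but only about 0.56x when both cost the same.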
vLLM manages the key/value (KV) cache with PagedAttention, a page-based allocation strategy that splits the KV cache into small fixed-size memory blocks (pages). Instead of reserving a large contiguous tensor for each request, vLLM allocates blocks on demand and locates them through a per-request block table. This nearly eliminates memory fragmentation and lets concurrent requests share blocks, which improves both GPU memory efficiency and scalability under mixed batch sizes and long prompts. In addition, vLLM supports speculative decoding out of the box, which makes this approach straightforward to apply.

SGLang is another state-of-the-art LLM inference engine that achieves high throughput. It reduces redundant KV-cache computation with RadixAttention, which keeps the KV cache of earlier requests in a radix tree indexed by token sequences. A new request that shares a prefix with an earlier one reuses the cached KV entries instead of recomputing them, and cached entries are evicted in least-recently-used order when memory runs short; this helps most when many concurrent sessions share prompts. However, SGLang's speculative decoding support is limited to [EAGLE](https://arxiv.org/abs/2401.15077) and EAGLE-3, which require the draft and target models to share the same hidden size. This restriction significantly narrows the choice of suitable draft models and, in our setting, reduced throughput.

We report both throughput and perplexity for each configuration in the next section. Among all configurations, the one achieving the highest throughput was selected as our final submission on Kaggle.
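The bookkeeping behind PagedAttention can be illustrated with a toy block allocator. This sketch is a simplification under stated assumptions, not vLLM's code. It tracks only which physical block holds each token's K/V entries, and it ignores the actual tensors, copy-on-write sharing, and preemption. The class and method names are hypothetical.

```python
from typing import Dict, List, Tuple


class PagedKVCache:
    """Toy PagedAttention bookkeeping: fixed-size blocks plus per-sequence block tables."""

    def __init__(self, num_blocks: int, block_size: int = 16) -> None:
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}  # seq_id -> physical block ids
        self.lengths: Dict[int, int] = {}             # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> None:
        """Reserve a slot for one more token, grabbing a new block only when needed."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # no block yet, or the last block is full
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1

    def locate(self, seq_id: int, pos: int) -> Tuple[int, int]:
        """Translate a logical token position into (physical block, offset), like a page table."""
        return self.block_tables[seq_id][pos // self.block_size], pos % self.block_size

    def release(self, seq_id: int) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


if __name__ == "__main__":
    cache = PagedKVCache(num_blocks=8, block_size=16)
    for _ in range(20):
        cache.append_token(seq_id=0)  # 20 tokens need 2 blocks; 12 slots of the second stay free
    for _ in range(5):
        cache.append_token(seq_id=1)  # 5 tokens need 1 block
    print(cache.block_tables, cache.locate(0, 17), len(cache.free_blocks))
```

Because memory is reserved one block at a time, at most `block_size - 1` slots per sequence go unused, instead of a whole buffer sized for the maximum sequence length.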
## Experimental Results

GPTQ + vLLM + Speculative Decoding result screenshots:

PPL: 11.12 (measured on the base model without an inference framework)

![image](https://hackmd.io/_uploads/SkHR6O6Mxl.png)

Throughput: 91.97 tokens/s

<!-- ![image](https://hackmd.io/_uploads/r17e0O6fgx.png) -->

![image](https://hackmd.io/_uploads/HyCYB2pflx.png)

| Method | Framework | Throughput (tokens/s) | PPL |
|------|----------|-----------|---|
| GPTQ | vLLM | 84.01 | 11.12 |
| GPTQ + Speculative Decoding | vLLM | 91.97 | 11.12 |
| GPTQ | SGLang | 90.1 | 11.12 |
| GPTQ + Speculative Decoding | SGLang | 57.1 | 11.12 |
| AWQ + Speculative Decoding | vLLM | 71.87 | 10.78 |
| AWQ | vLLM | 19.35 | 10.78 |
| HQQ | N/A | 3.9 | 11.23 |

<!-- ![image](https://hackmd.io/_uploads/HJ3QJDTGex.png) -->

<!-- ![image](https://hackmd.io/_uploads/HJYMtuazgx.png) -->

![cmp](https://hackmd.io/_uploads/ryFG_3pMge.png)

From the table and figure, we observe that GPTQ consistently outperforms AWQ in throughput within the vLLM framework, both without speculative decoding (84.01 vs. 19.35 tokens/s) and with it (91.97 vs. 71.87 tokens/s). Speculative decoding raises GPTQ throughput on vLLM by about 9.5% (from 84.01 to 91.97 tokens/s) but lowers it on SGLang by about 37% (from 90.1 to 57.1 tokens/s). Without speculative decoding, SGLang outperforms vLLM (90.1 vs. 84.01 tokens/s), which we attribute to its low-level kernel optimizations and KV-cache reuse via RadixAttention. With speculative decoding, however, SGLang only supports EAGLE-style algorithms, which require the target and draft models to share the same hidden size; we suspect the two models then compete for GPU resources, which hurts throughput. We compare vLLM and SGLang in more detail in the Insights section.

<!-- However, SGLang is not panacea, since it only supports EAGLE as the speculative decoding algorithm, the two models loaded on GPU will compete for resource and therefore harm the performance. -->
<!--
| Layers | Setting 1 | Setting 2 |
| :-----: | :---------: | :---------: |
| `self_attn.q_proj` | 4 bits | 4 bits |
| `self_attn.o_proj` | 4 bits | 4 bits |
| `self_attn.k_proj` | 4 bits | 2 bits |
| `self_attn.v_proj` | 4 bits | 2 bits |
| `mlp.gate_proj` | 4 bits | 4 bits |
| `mlp.up_proj` | 4 bits | 4 bits |
| `mlp.down_proj` | 4 bits | 4 bits |
| Throughput | 504.6 | 511.1 |
| Perplexity | 9.59 | 11.23 |
-->

## Team Member Contributions

| Method | Framework | Contributors |
|-----------------------------|----------|--------------|
| GPTQ | vLLM | 林睿騰 |
| GPTQ + Speculative Decoding | vLLM | 吳秉澍 |
| GPTQ | SGLang | 温柏萱 |
| GPTQ + Speculative Decoding | SGLang | 温柏萱 |
| AWQ | vLLM | 林睿騰 |
| AWQ + Speculative Decoding | vLLM | 吳秉澍 |
| HQQ | N/A | 林楷傑 |
| LoRA Fine-tuning | N/A | 林楷傑 |

## Insights / Explorations

### The Discovery of the Flawed Score Calculation Metric

Initially, we scored above 500 using merely HQQ + LoRA fine-tuning (Yup, we were the two mysterious submissions on Kaggle). However, we noticed that if the model kept outputting `\n`, the measured throughput increased significantly without harming the perplexity. This was because the scoring metric used the constant `max_new_tokens`, rather than the actual number of generated tokens, as the token count when computing tokens per second. Technically, we had "score-hacked" our way to glory. But as responsible and honest NYCU students, we dutifully reported the issue to the TAs, with a mild heartbreak.

![image](https://hackmd.io/_uploads/B1vu4RTMxx.png)

### vLLM vs. SGLang

| Framework | KV Cache | Batch Design |
| --------- | -------- | ------------ |
| SGLang | RadixAttention | persistent batching |
| vLLM | PagedAttention | continuous batching |

To fully understand the performance difference between vLLM and SGLang, we compare RadixAttention with PagedAttention and examine JIT key-value layout planning.
#### Overview

| Aspect | **RadixAttention** (SGLang) | **PagedAttention** (vLLM) |
| ---- | ------------------------------- | --------------------------- |
| Core idea | Reuse the KV cache across requests that share a prefix | Manage KV-cache memory efficiently, especially for long contexts and many concurrent requests |
| Background | Proposed with SGLang (Zheng et al.) | Proposed with vLLM (Kwon et al., SOSP 2023) |
| Mechanism | Index cached token sequences and their KV entries in a radix tree, and match each new request against the longest cached prefix | Store each sequence's KV cache in fixed-size pages and read or write them through a block table when needed |
| Computation | Skip prefill for the matched prefix and compute KV only for the new tokens | The attention kernel gathers K/V from pages through the block table, just like virtual-memory address translation |
| Memory management | Keep cached prefixes in GPU memory and evict least-recently-used entries when space is needed | Allocate pages on demand and free them when a sequence finishes, so memory is managed dynamically |
| Key advantage | Faster prefill for shared prompts (system prompts, few-shot examples, multi-turn chat) | Less fragmentation, larger batches, and longer contexts |
| Large context support | Not its main goal | Yes |
| Inference speed impact | Yes | Yes |
| Changes KV-cache structure | Yes (adds a radix-tree index over cached entries) | Yes (block-based layout) |

#### Summary

| Metric | RadixAttention | PagedAttention |
| ------ | --------------- | --------------- |
| Problem solved | Redundant KV-cache computation for shared prefixes | Memory management of the KV cache, especially for long contexts |
| Core idea | Prefix sharing via a radix tree with LRU eviction | Paging, as in virtual memory |
| Best use case | Workloads with shared prefixes, e.g. few-shot prompting, multi-turn chat | Long-context and high-concurrency tasks, e.g. document summarization, RAG |

JIT layout planning computes the memory organization at runtime, just before each decoding step begins. It takes into account the current batch shape, prompt lengths, the GPU's memory-alignment requirements, and hardware-friendly access patterns in order to distribute tokens across threads and warps, lay out each attention head's K/V vectors, and coalesce memory reads and writes to reduce latency. In our experiments without speculative decoding, SGLang outperformed vLLM by about 6.1 tokens/s (90.1 vs. 84.01 tokens/s), which we attribute in part to this JIT layout planning.
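The prefix reuse that RadixAttention provides can be illustrated with a toy token-level prefix tree. This is a simplification under stated assumptions, not SGLang's implementation. SGLang uses a compressed radix tree whose nodes point to KV-cache memory and evicts least-recently-used leaves; this sketch stores one token per node, keeps no KV tensors, and never evicts. The class, function, and token values are hypothetical.

```python
from typing import Dict, List


class PrefixTree:
    """Toy prefix tree over token ids; each node stands for one cached KV entry."""

    def __init__(self) -> None:
        self.root: Dict[int, dict] = {}

    def insert(self, tokens: List[int]) -> None:
        """Record that the KV cache for this token sequence is now available."""
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})

    def match_prefix(self, tokens: List[int]) -> int:
        """Length of the longest cached prefix of `tokens` (tokens whose KV can be reused)."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node, matched = node[tok], matched + 1
        return matched


def tokens_to_prefill(tree: PrefixTree, request: List[int]) -> int:
    """Only the unmatched suffix of a request needs a prefill forward pass."""
    return len(request) - tree.match_prefix(request)


if __name__ == "__main__":
    tree = PrefixTree()
    system_prompt = [101, 7, 7, 42]              # hypothetical shared prompt tokens
    tree.insert(system_prompt + [5, 6])          # first request, now cached
    print(tokens_to_prefill(tree, system_prompt + [9]))  # only the new token is prefilled
    print(tokens_to_prefill(tree, [3, 4]))               # nothing shared, full prefill
```

The saving applies only to prefill of shared prefixes; decoding new tokens still costs one forward pass per step in either framework.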
### vLLM Environment Variable and Compilation Config

In vLLM, the attention backend can be selected with the environment variable `VLLM_ATTENTION_BACKEND`, whose options include `TORCH_SDPA`, `FLASH_ATTN`, `XFORMERS`, `ROCM_FLASH`, `FLASHINFER`, and `FLASHMLA`. However, not all backends are compatible with T4 GPUs. Among the compatible options, we found that the default, `XFORMERS`, yields the highest throughput.

On the T4 server provided by NYCU, vLLM occasionally crashed because of the default CUDA graph settings in the compilation config. To resolve this, you can either enforce eager mode or adjust the CUDA graph settings. To achieve higher throughput on the "ta004" server, we applied the following compilation config:

```python
compilation_config = {
    "cudagraph_capture_sizes": [1, 2, 4, 8, 16],
    "max_capture_size": 16,
}
```

This setup keeps the benefit of CUDA graphs while capturing only small batch sizes. Increasing the capture sizes may boost performance, while reducing them further can help avoid runtime errors, especially on constrained hardware such as T4 GPUs.

### Additional Attempts

#### Flash Attention Attempt (Fail)

Because FlashAttention-2 does not support the NVIDIA T4 (Turing architecture), the built-in FlashAttention-2 path in Hugging Face transformers is unavailable on it, so we tried to implement our own. We based our kernel on the Triton implementation in this [repo](https://github.com/hkproj/triton-flash-attention) and replaced the attention interface in Llama with our own. However, the attention output became NaN, indicating a numerical error. We suspect the reason is that Llama 3.2 uses grouped-query attention (GQA) rather than standard multi-head attention, and most open-source Triton FlashAttention-2 kernels do not support MQA/GQA.
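For reference when debugging a GQA-aware kernel, the head mapping itself is simple. With $H_q$ query heads and $H_{kv}$ key/value heads, query head $h$ reads key/value head $\lfloor h / (H_q / H_{kv}) \rfloor$. The sketch below assumes the 24 query heads and 8 key/value heads of `Llama-3.2-3B-Instruct` (from its public config). In pure Python, it mirrors the `unsqueeze(2).expand(...).reshape(...)` ordering used in the kernel wrapper that follows. It also shows that tiling the heads instead, as `torch.Tensor.repeat` on the head dimension would, gives a different and incorrect mapping.

```python
from typing import List

NUM_Q_HEADS, NUM_KV_HEADS = 24, 8  # assumed Llama-3.2-3B-Instruct values


def gqa_kv_head(h: int, num_q_heads: int = NUM_Q_HEADS, num_kv_heads: int = NUM_KV_HEADS) -> int:
    """K/V head read by query head h under grouped-query attention."""
    assert num_q_heads % num_kv_heads == 0
    return h // (num_q_heads // num_kv_heads)


def expand_reshape_order(num_q_heads: int = NUM_Q_HEADS, num_kv_heads: int = NUM_KV_HEADS) -> List[int]:
    """K/V head at each position after [kv, 1] -> expand [kv, rep] -> reshape [kv * rep] (row-major)."""
    rep = num_q_heads // num_kv_heads
    return [kv for kv in range(num_kv_heads) for _ in range(rep)]


def tiled_order(num_q_heads: int = NUM_Q_HEADS, num_kv_heads: int = NUM_KV_HEADS) -> List[int]:
    """K/V head at each position if the heads are tiled instead (like Tensor.repeat on the head dim)."""
    rep = num_q_heads // num_kv_heads
    return [kv for _ in range(rep) for kv in range(num_kv_heads)]


if __name__ == "__main__":
    print(expand_reshape_order()[:6])  # K/V heads used by query heads 0-5
    print(tiled_order()[:6])           # the same positions under tiling
```

The expand-then-reshape order is the same ordering as the `repeat_kv` helper in Hugging Face's Llama implementation, so a custom kernel fed this layout sees the heads the model was trained with.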
```python
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import itertools
import argparse
from functools import partial
from typing import Optional, Tuple

DEVICE = 'cuda'


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    dropout_p, dropout_seed,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, USE_DROPOUT: tl.constexpr, IS_BF16: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    dropout_scale = 1.0 / (1.0 - dropout_p) if USE_DROPOUT else 1.0
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- apply dropout if specified at compile time --
        if USE_DROPOUT:
            # Create a unique offset for each element in the p matrix
            # The offset needs to be unique across both dimensions of p
            row_offsets = offs_m[:, None] + start_m * BLOCK_M
            col_offsets = (start_n + offs_n[None, :])
            # Combine row and column offsets into a unique offset
            # Multiply row offset by N_CTX to ensure uniqueness across the matrix
            combined_offsets = row_offsets * N_CTX + col_offsets
            # Generate the dropout mask using combined offsets
            rng = tl.rand(dropout_seed, combined_offsets)
            dropout_mask = rng > dropout_p
            # Apply the dropout and scale by 1/(1-p) to maintain expected values
            p = tl.where(dropout_mask, p / (1.0 - dropout_p), 0.0)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if IS_BF16:
            p = p.to(tl.bfloat16)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [32, 64, 128]
    for BN in [16, 32, 64]
    for s in [3, 4, 5, 7]
    for w in [4, 8]
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    if BLOCK_M < BLOCK_N:
        return False
    return True


@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, dropout_p, dropout_seed, M, Out,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              Z, H, N_CTX,  #
              HEAD_DIM: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr,  #
              USE_DROPOUT: tl.constexpr,  #
              IS_BF16: tl.constexpr
              ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale, dropout_p, dropout_seed,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, USE_DROPOUT, IS_BF16)
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale, dropout_p, dropout_seed,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, USE_DROPOUT, IS_BF16)
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


def flash_attn_kernel(
    q, k, v,
    causal,
    sm_scale,
):
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = v.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    o = torch.empty_like(q)
    stage = 3 if causal else 1
    extra_kern_args = {}

    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
    # Determine if we should use dropout at compile time
    USE_DROPOUT = False
    # NOTE: N_CTX is taken from q only, so this kernel assumes q and k have the same
    # (padded) sequence length; that does not hold during KV-cached decoding.
    _attn_fwd[grid](
        q, k, v, sm_scale, 0.0, 0, M, o,  #
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
        q.shape[0], q.shape[1],  #
        N_CTX=q.shape[2],  #
        HEAD_DIM=HEAD_DIM_K,  #
        STAGE=stage,  #
        USE_DROPOUT=USE_DROPOUT,  #
        IS_BF16=(q.dtype == torch.bfloat16),
        **extra_kern_args)
    return o


def my_flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # ignored: the kernel only applies a causal mask
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    # This is before the transpose
    seq_len = query.shape[2]

    # # FA2 uses non-transposed inputs
    # query = query.transpose(1, 2)
    # key = key.transpose(1, 2)
    # value = value.transpose(1, 2)

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (usually our RMSNorm modules handle it correctly)
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    # Actually apply the cast computed above (it was previously computed but never used)
    if target_dtype is not None:
        query, key, value = query.to(target_dtype), key.to(target_dtype), value.to(target_dtype)

    BLOCK_SIZE = 128
    q = query
    k = key
    v = value
    batch_size, num_q_heads, seq_length, head_dim = q.shape
    num_kv_heads = k.shape[1]
    # GQA: number of query heads sharing each K/V head (24 / 8 = 3 for Llama-3.2-3B)
    q_repeats_per_kv = num_q_heads // num_kv_heads
    seq_len_q = seq_length
    seq_len_k = k.shape[2]
    original_seq_len_q = seq_len_q
    original_seq_len_k = seq_len_k
    # Pad both sequence lengths up to a multiple of BLOCK_SIZE so every tile is full
    pad_q = (BLOCK_SIZE - seq_len_q % BLOCK_SIZE) % BLOCK_SIZE
    pad_k = (BLOCK_SIZE - seq_len_k % BLOCK_SIZE) % BLOCK_SIZE
    if pad_q > 0:
        q_padding = torch.zeros(batch_size, num_q_heads, pad_q, head_dim, dtype=q.dtype, device=q.device)
        q = torch.cat([q, q_padding], dim=2)
        seq_len_q += pad_q
    if pad_k > 0:
        k_padding = torch.zeros(batch_size, num_kv_heads, pad_k, head_dim, dtype=k.dtype, device=k.device)
        v_padding = torch.zeros(batch_size, num_kv_heads, pad_k, head_dim, dtype=v.dtype, device=v.device)
        k = torch.cat([k, k_padding], dim=2)
        v = torch.cat([v, v_padding], dim=2)
        seq_len_k += pad_k
    seq_length = max(seq_len_q, seq_len_k)
    if q_repeats_per_kv > 1:
        # Repeat each key and value head to match the number of query heads
        # Map each query head to its corresponding k/v head
        # For example, q heads [0,1,2] map to k/v head 0, q heads [3,4,5] to k/v head 1, etc.
        # Reshape k and v to repeat each head q_repeats_per_kv times
        # expand itself does not copy, but the reshape that follows materializes a copy
        k = k.unsqueeze(2)  # [batch_size, num_kv_heads, 1, seq_length, head_dim]
        k = k.expand(batch_size, num_kv_heads, q_repeats_per_kv, seq_len_k, head_dim)
        k = k.reshape(batch_size, num_q_heads, seq_len_k, head_dim)
        v = v.unsqueeze(2)  # [batch_size, num_kv_heads, 1, seq_length, head_dim]
        v = v.expand(batch_size, num_kv_heads, q_repeats_per_kv, seq_len_k, head_dim)
        v = v.reshape(batch_size, num_q_heads, seq_len_k, head_dim)

    # # Transpose q, k, v from [batch, seq_len, heads, head_dim] to [batch, heads, seq_len, head_dim]
    # # which is what the Triton implementation expects
    # q = q.transpose(1, 2).contiguous()  # Ensure contiguity after transpose
    # k = k.transpose(1, 2).contiguous()
    # v = v.transpose(1, 2).contiguous()
    q = q.contiguous()  # no-op if already contiguous
    k = k.contiguous()
    v = v.contiguous()

    out = flash_attn_kernel(
        q, k, v,
        causal=module.is_causal,
        # fall back to the standard 1/sqrt(head_dim) scale when none is passed
        sm_scale=scaling if scaling is not None else head_dim ** -0.5,
    )
    out = out.transpose(1, 2).contiguous()  # [batch, seq_len, heads, head_dim]
    if pad_q > 0:
        out = out[:, :original_seq_len_q, :, :]
    return out, None
```

#### GPTQ Attempt with Lower Throughput

One of our earlier GPTQ attempts failed to reach a higher score because we mistakenly used `BensonW/EAI-Final-draft-model`, which was quantized with AWQ, as the draft model, while the target model was quantized with GPTQ. As the log below shows, this setup reached only 66.7 tokens/s, compared with 91.97 tokens/s for our final GPTQ draft model. This result suggests that the draft model must be significantly faster than the target model; otherwise, speculative decoding may hurt overall performance.

```
Prompt: How to learn a new language?
Response: but rewarding experience. Here are some steps you can take to learn a new language:
1. **Set your goals**: Determine why you want to learn a new language and what you want to achieve. Are you traveling to a foreign country? Do you want to improve your career opportunities? Are you interested in reading classic literature in the original language?
2. **Choose your language**: Select a language that interests you and is feasible to learn. Consider factors such as language difficulty, cultural relevance, and availability of resources.
3. **Learn the basics**: Start with the basics of the language, such as the alphabet, grammar rules, and common phrases. You can find many resources online, such as language learning apps, websites, and language exchange programs.
4. **Practice regularly**: Practice speaking, writing, listening, and reading in the target language regularly. You can use language learning apps, watch TV shows and movies with subtitles, listen to podcasts and radio shows, and read books and articles in the target language.
5. **Immerse yourself in the language**: Surround yourself with the language as much as possible. Listen to music and podcasts in the target language, watch TV shows and movies with subtitles, and speak with native speakers.

Time Record: [3.75894140625, 3.686145751953125, 3.75512939453125, 3.791216064453125, 3.719895751953125, 3.71812890625, 3.791583740234375, 3.675305419921875, 3.639480712890625, 3.686208740234375]
Throughput Record: [65.97602175645778, 67.27894573040032, 66.04299717638828, 65.41436725943328, 66.66853496359086, 66.7002156872839, 65.40802392634747, 67.47738532306022, 68.14158929915807, 67.2777960979583] toks/s

Throughput: 66.7 toks/s
```

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import torch
import random
import numpy as np
from tqdm.auto import tqdm
import csv

# === Set Up ===
torch.manual_seed(0)
random.seed(0)
max_new_tokens = 256
sp = SamplingParams(max_tokens=max_new_tokens, temperature=0.0)
sp_ppl = SamplingParams(max_tokens=1, prompt_logprobs=1, temperature=0.0)

# === Load GPTQ quantized model ===
model_name = "AlisonWen/Llama-3.2-3B-Instruct-lora-gptq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = LLM(model=model_path)
model = LLM(model=model_name,
            dtype="auto",
            max_model_len=2048,
            gpu_memory_utilization=0.8,
            tensor_parallel_size=1,
            speculative_config={
                "model": "BensonW/EAI-Final-draft-model",
                "num_speculative_tokens": 6,
            },
            )

# === Warm-up ===
prompt = "Explain what AI is."
tputs, time_record = [], []
for _ in tqdm(range(5), desc="Warm Up..."):
    _ = model.generate([prompt], sp)

# === Inference ===
test_prompt = "How to learn a new language?"
for _ in tqdm(range(10), desc="Test Inference"):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    generated = model.generate([test_prompt], sp)[0]
    # vLLM's `token_ids` already exclude the prompt, so every entry is a generated token.
    # (The run logged above subtracted the prompt length here and sliced it off when
    # decoding, which is why its response starts mid-sentence.)
    output_ids = generated.outputs[0].token_ids
    output_len = len(output_ids)
    end.record()
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end)
    tput = output_len / (elapsed_ms / 1000)
    tputs.append(tput)
    time_record.append(elapsed_ms / 1000)
response = tokenizer.decode(output_ids, skip_special_tokens=True)

# Drop the two fastest and two slowest runs, then average the rest
sorted_tputs = np.sort(tputs)[2:-2]
org_tput = np.mean(sorted_tputs)

print(f'Prompt: {test_prompt}\nResponse: {response}\n')
print(f'Time Record: {time_record}')
print(f'Throughput Record: {tputs} toks/s\n')
print(f'Throughput: {org_tput:.1f} toks/s')

# === Save Result ===
with open("result.csv", mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Id", "value"])
    writer.writerow([1, round(org_tput, 1)])
```

#### Different Configurations

Enabling dynamic `rope_scaling` (factor 1.5) to handle different output lengths lowered the throughput from 66.7 to 63.9 tokens/s.
```python
model = LLM(model=model_name,
            dtype="auto",
            max_model_len=2048,
            gpu_memory_utilization=0.9,
            tensor_parallel_size=1,
            rope_scaling={"rope_type": "dynamic", "factor": 1.5},  # Corrected RoPE key
            hf_overrides={},  # handles different length generation
            # enforce_eager=True,
            speculative_config={
                "model": "BensonW/EAI-Final-draft-model",
                "num_speculative_tokens": 6,
            },
            )
```

Results

```
Prompt: How to learn a new language?
Response: but rewarding experience. Here are some steps you can follow to learn a new language:
1. **Set your goals**: Identify why you want to learn the language and what you want to achieve. Are you traveling, working, or studying abroad? Do you want to improve your career prospects or connect with family and friends? Setting clear goals will help you stay motivated.
2. **Choose your resources**: There are many resources available to learn a new language, including:
    * Language learning apps (e.g., Duolingo, Babbel)
    * Language exchange websites (e.g., italki, Conversation Exchange)
    * Language classes (e.g., community college, language school)
    * Textbooks and language learning books
    * Online courses and tutorials
3. **Start with the basics**: Begin with the fundamentals of the language, including:
    * Alphabet and pronunciation
    * Basic grammar rules (e.g., verb conjugation, sentence structure)
    * Common vocabulary (e.g., greetings, introductions)
4. **Practice consistently**: Make language learning a regular part of your routine, even if it's just 10-15 minutes a day. Practice speaking, listening, reading, and writing to reinforce

Time Record: [3.917848876953125, 3.916132568359375, 3.99327734375, 3.874427978515625, 3.858979248046875, 3.912814697265625, 3.83773095703125, 3.84310791015625, 3.88774462890625, 3.815467041015625]
Throughput Record: [63.300042392872314, 63.32778466278969, 62.104376593865275, 64.00944897548825, 64.26569931038601, 63.38148345570997, 64.62151796900457, 64.53110498006208, 63.79019808967508, 64.99859580335564] toks/s

Throughput: 63.884286579018514 toks/s
```