[vllm] benchmark_serving.py
===
###### tags: `LLM / inference`
###### tags: `LLM`, `inference`, `推論`, `vllm`, `benchmark_serving.py`
<br>
[TOC]
<br>
## benchmark_serving.py
> https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py
```=
$ python3 benchmark_serving.py -h
usage: benchmark_serving.py [-h]
[--backend {tgi,vllm,lmdeploy,deepspeed-mii,openai,openai-chat,tensorrt-llm,scalellm}]
[--base-url BASE_URL] [--host HOST] [--port PORT]
[--endpoint ENDPOINT] [--dataset DATASET]
[--dataset-name {sharegpt,sonnet,random}]
[--dataset-path DATASET_PATH] --model MODEL [--tokenizer TOKENIZER]
[--best-of BEST_OF] [--use-beam-search] [--num-prompts NUM_PROMPTS]
[--sharegpt-output-len SHAREGPT_OUTPUT_LEN]
[--sonnet-input-len SONNET_INPUT_LEN]
[--sonnet-output-len SONNET_OUTPUT_LEN]
[--sonnet-prefix-len SONNET_PREFIX_LEN]
[--random-input-len RANDOM_INPUT_LEN]
[--random-output-len RANDOM_OUTPUT_LEN]
[--random-range-ratio RANDOM_RANGE_RATIO]
[--request-rate REQUEST_RATE] [--seed SEED] [--trust-remote-code]
[--disable-tqdm] [--save-result] [--metadata [KEY=VALUE ...]]
[--result-dir RESULT_DIR] [--result-filename RESULT_FILENAME]
Benchmark the online serving throughput.
options:
-h, --help show this help message and exit
--backend {tgi,vllm,lmdeploy,deepspeed-mii,openai,openai-chat,tensorrt-llm,scalellm}
--base-url BASE_URL Server or API base url if not using http host and port.
--host HOST
--port PORT
--endpoint ENDPOINT API endpoint.
--dataset DATASET Path to the ShareGPT dataset, will be deprecated in the next release.
--dataset-name {sharegpt,sonnet,random}
Name of the dataset to benchmark on.
--dataset-path DATASET_PATH
Path to the dataset.
--model MODEL Name of the model.
--tokenizer TOKENIZER
Name or path of the tokenizer, if not using the default tokenizer.
--best-of BEST_OF Generates `best_of` sequences per prompt and returns the best one.
--use-beam-search
--num-prompts NUM_PROMPTS
Number of prompts to process.
--sharegpt-output-len SHAREGPT_OUTPUT_LEN
Output length for each request. Overrides the output length from the
ShareGPT dataset.
--sonnet-input-len SONNET_INPUT_LEN
Number of input tokens per request, used only for sonnet dataset.
--sonnet-output-len SONNET_OUTPUT_LEN
Number of output tokens per request, used only for sonnet dataset.
--sonnet-prefix-len SONNET_PREFIX_LEN
Number of prefix tokens per request, used only for sonnet dataset.
--random-input-len RANDOM_INPUT_LEN
Number of input tokens per request, used only for random sampling.
--random-output-len RANDOM_OUTPUT_LEN
Number of output tokens per request, used only for random sampling.
--random-range-ratio RANDOM_RANGE_RATIO
Range of sampled ratio of input/output length, used only for random
sampling.
--request-rate REQUEST_RATE
Number of requests per second. If this is inf, then all the requests are
sent at time 0. Otherwise, we use Poisson process to synthesize the
request arrival times.
--seed SEED
--trust-remote-code Trust remote code from huggingface
--disable-tqdm Specify to disable tqdm progress bar.
--save-result Specify to save benchmark results to a json file
--metadata [KEY=VALUE ...]
Key-value pairs (e.g, --metadata version=0.3.3 tp=1) for metadata of this
run to be saved in the result JSON file for record keeping purposes.
--result-dir RESULT_DIR
Specify directory to save benchmark json results.If not specified, results
are saved in the current directory.
--result-filename RESULT_FILENAME
Specify the filename to save benchmark json results.If not specified,
results will be saved in
{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json format.
```
<br>
<hr>
<br>
## 輸入參數的整理表
### 通用參數
| 參數 | 類型 | 預設值 | 必要 | 說明 |
|------------------------|-------------------|-----------------|------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--backend` | 字串 (`str`) | `"vllm"` | 否 | 後端服務類型。可選值為 `"vllm"` 等(依實際可用後端而定)。 |
| `--base-url` | 字串 (`str`) | 無 (`None`) | 否 | 如果不使用 `--host` 和 `--port`,則指定服務器或 API 的基礎 URL。 |
| `--host` | 字串 (`str`) | `"localhost"` | 否 | 服務器主機地址。 |
| `--port` | 整數 (`int`) | `8000` | 否 | 服務器埠號。 |
| `--endpoint` | 字串 (`str`) | `"/v1/completions"` | 否 | API 端點。 |
| `--dataset` | 字串 (`str`) | 無 (`None`) | 否 | ShareGPT 數據集的路徑,將在下一版本中棄用。 |
| `--dataset-name` | 字串 (`str`) | `"sharegpt"` | 否 | 用於基準測試的數據集名稱。可選值為 `"sharegpt"`、`"sonnet"`、`"random"`、`"hf"`。 |
| `--dataset-path` | 字串 (`str`) | 無 (`None`) | 否 | ShareGPT/Sonnet 數據集的路徑,或使用 Hugging Face 數據集時的數據集 ID。 |
| `--max-concurrency` | 整數 (`int`) | 無 (`None`) | 否 | 最大並發請求數。用於模擬高層元件限制最大並發請求數的環境。 |
| `--model` | 字串 (`str`) | **無** | 是 | 模型的名稱。 |
| `--tokenizer` | 字串 (`str`) | 無 (`None`) | 否 | 分詞器的名稱或路徑,如果不使用預設的分詞器。 |
| `--best-of` | 整數 (`int`) | `1` | 否 | 每個提示生成的序列數,並返回最佳的序列。 |
| `--use-beam-search` | 開關 (`flag`) | `False` | 否 | 使用 Beam Search 演算法。 |
| `--num-prompts` | 整數 (`int`) | `1000` | 否 | 要處理的提示數量。 |
| `--logprobs` | 整數 (`int`) | 無 (`None`) | 否 | 每個標記要計算並返回的 logprob 數量。如果未指定,則根據是否啟用 Beam Search 決定計算方式。 |
| `--request-rate` | 浮點數 (`float`) | `inf` | 否 | 每秒請求數量。如果為無限大,則所有請求同時發送。否則,使用泊松過程模擬請求到達時間。 |
| `--seed` | 整數 (`int`) | `0` | 否 | 隨機數生成器的種子。 |
| `--trust-remote-code` | 開關 (`flag`) | `False` | 否 | 信任來自 Hugging Face 的遠端程式碼。 |
| `--disable-tqdm` | 開關 (`flag`) | `False` | 否 | 禁用 tqdm 進度條。 |
| `--profile` | 開關 (`flag`) | `False` | 否 | 使用 Torch Profiler。端點必須以 `VLLM_TORCH_PROFILER_DIR` 啟動以啟用分析器。 |
| `--save-result` | 開關 (`flag`) | `False` | 否 | 保存基準測試結果到 JSON 檔案。 |
| `--metadata` | `KEY=VALUE` 格式 | 無 | 否 | 以鍵值對形式(如 `--metadata version=0.3.3 tp=1`)指定此運行的元數據,將保存到結果 JSON 檔案中以供記錄。 |
| `--result-dir` | 字串 (`str`) | 無 (`None`) | 否 | 指定保存基準測試結果 JSON 檔案的目錄。如果未指定,結果將保存到當前目錄。 |
| `--result-filename` | 字串 (`str`) | 無 (`None`) | 否 | 指定保存基準測試結果 JSON 檔案的檔名。如果未指定,結果將以預設格式保存。 |
| `--ignore-eos` | 開關 (`flag`) | `False` | 否 | 發送基準測試請求時設置 `ignore_eos` 標誌。警告:`deepspeed_mii` 和 `tgi` 不支援此功能。 |
| `--percentile-metrics` | 字串 (`str`) | `"ttft,tpot,itl"` | 否 | 要報告百分位數的指標列表。允許的指標名稱為 `"ttft"`、`"tpot"`、`"itl"`、`"e2el"`。 |
| `--metric-percentiles` | 字串 (`str`) | `"99"` | 否 | 要報告的百分位數列表,例如 `"25,50,75"`。使用 `--percentile-metrics` 選擇指標。 |
| `--goodput` | `KEY:VALUE` 格式 | 無 | 否 | 指定良好吞吐量的服務級別目標(SLO),格式為 `"指標名稱:值"`,值以毫秒為單位。允許的指標名稱為 `"ttft"`、`"tpot"`、`"e2el"`。 |
### Sonnet 數據集選項
| 參數 | 類型 | 預設值 | 必要 | 說明 |
|------------------------|---------------|--------|------|--------------------------------------------------------------|
| `--sonnet-input-len` | 整數 (`int`) | `550` | 否 | 每個請求的輸入標記數,僅用於 Sonnet 數據集。 |
| `--sonnet-output-len` | 整數 (`int`) | `150` | 否 | 每個請求的輸出標記數,僅用於 Sonnet 數據集。 |
| `--sonnet-prefix-len` | 整數 (`int`) | `200` | 否 | 每個請求的前綴標記數,僅用於 Sonnet 數據集。 |
### ShareGPT 數據集選項
| 參數 | 類型 | 預設值 | 必要 | 說明 |
|--------------------------|---------------|------------|------|--------------------------------------------------------------|
| `--sharegpt-output-len` | 整數 (`int`) | 無 (`None`) | 否 | 每個請求的輸出長度,將覆蓋 ShareGPT 數據集中的輸出長度。 |
### Random 數據集選項
| 參數 | 類型 | 預設值 | 必要 | 說明 |
|--------------------------|-------------------|--------|------|------------------------------------------------------------------------------------|
| `--random-input-len` | 整數 (`int`) | `1024` | 否 | 每個請求的輸入標記數,僅用於隨機採樣。 |
| `--random-output-len` | 整數 (`int`) | `128` | 否 | 每個請求的輸出標記數,僅用於隨機採樣。 |
| `--random-range-ratio` | 浮點數 (`float`) | `1.0` | 否 | 輸入/輸出長度的採樣比例範圍,僅用於隨機採樣。 |
| `--random-prefix-len` | 整數 (`int`) | `0` | 否 | 隨機上下文前的固定前綴標記數。 |
### Hugging Face 數據集選項
| 參數 | 類型 | 預設值 | 必要 | 說明 |
|------------------------|---------------|------------|------|--------------------------------------------------------------|
| `--hf-subset` | 字串 (`str`) | 無 (`None`)| 否 | Hugging Face 數據集的子集名稱。 |
| `--hf-split` | 字串 (`str`) | 無 (`None`)| 否 | Hugging Face 數據集的分割名稱(如 `"train"`、`"test"`)。 |
| `--hf-output-len` | 整數 (`int`) | 無 (`None`)| 否 | 每個請求的輸出長度,將覆蓋從 Hugging Face 數據集中採樣的輸出長度。 |
**註**:`必要`欄位為 **是** 表示該參數為必須提供,為 **否** 則為可選參數。
<br>
<hr>
<br>
## 函數列表與用法
> 程式中主要函數的用法整理:
### 0. [總覽] 執行入口
根據 `args.dataset_name` ,呼叫不同 function 作為開始
| `args.dataset_name` | function |
|---------------------|----------|
| `sharegpt` | `sample_sharegpt_requests` |
| `sonnet` | `sample_sonnet_requests` |
| `hf` | `sample_hf_requests` |
| `random` | `sample_random_requests` |
<br>
### 1. `sample_sharegpt_requests`
- **功能**:從 ShareGPT 數據集中抽樣請求,用於基準測試。
- **參數**:
- `dataset_path` (`str`):ShareGPT 數據集的路徑。
- `num_requests` (`int`):要抽樣的請求數量。
- `tokenizer` (`PreTrainedTokenizerBase`):用於分詞的 tokenizer。
- `fixed_output_len` (`Optional[int]`, 預設值:`None`):固定的輸出長度,若為 `None`,則使用數據集中提供的輸出長度。
- **返回值**:`List[Tuple[str, int, int, None]]`,每個元素包含:
- 輸入文本 (`str`)
- 輸入長度 (`int`)
- 輸出長度 (`int`)
- 多媒體內容(此處為 `None`)
### 2. `sample_sonnet_requests`
- **功能**:從 Sonnet 數據集中生成請求,用於基準測試。
- **參數**:
- `dataset_path` (`str`):Sonnet 數據集的路徑。
- `num_requests` (`int`):要生成的請求數量。
- `input_len` (`int`):每個請求的輸入標記數。
- `output_len` (`int`):每個請求的輸出標記數。
- `prefix_len` (`int`):每個請求的前綴標記數。
- `tokenizer` (`PreTrainedTokenizerBase`):用於分詞的 tokenizer。
- **返回值**:`List[Tuple[str, str, int, int, None]]`,每個元素包含:
- 原始輸入文本 (`str`)
- 格式化後的輸入文本 (`str`)
- 輸入長度 (`int`)
- 輸出長度 (`int`)
- 多媒體內容(此處為 `None`)
### 3. `sample_hf_requests`
- **功能**:從 Hugging Face 數據集中抽樣請求,用於基準測試。
- **參數**:
- `dataset_path` (`str`):HF 數據集的 ID。
- `dataset_subset` (`str`):HF 數據集的子集名稱。
- `dataset_split` (`str`):HF 數據集的分割(如 `"train"`、`"test"`)。
- `num_requests` (`int`):要抽樣的請求數量。
- `tokenizer` (`PreTrainedTokenizerBase`):用於分詞的 tokenizer。
- `random_seed` (`int`):隨機種子。
- `fixed_output_len` (`Optional[int]`, 預設值:`None`):固定的輸出長度,若為 `None`,則使用數據集中提供的輸出長度。
- **返回值**:`List[Tuple[str, int, int, Optional[Dict[str, Collection[str]]]]]`,每個元素包含:
- 輸入文本 (`str`)
- 輸入長度 (`int`)
- 輸出長度 (`int`)
- 多媒體內容(如有)
### 4. `sample_random_requests`
- **功能**:隨機生成請求,用於基準測試。
- **參數**:
- `prefix_len` (`int`):固定前綴的標記數。
- `input_len` (`int`):每個請求的輸入標記數。
- `output_len` (`int`):每個請求的輸出標記數。
- `num_prompts` (`int`):要生成的請求數量。
- `range_ratio` (`float`):輸入/輸出長度的隨機範圍比例。
- `tokenizer` (`PreTrainedTokenizerBase`):用於分詞的 tokenizer。
- **返回值**:`List[Tuple[str, int, int]]`,每個元素包含:
- 輸入文本 (`str`)
- 輸入長度 (`int`)
- 輸出長度 (`int`)
### 5. `get_request`
- **功能**:生成異步請求,用於模擬請求的到達時間。
- **參數**:
- `input_requests` (`List[Tuple[str, int, int]]`):請求列表。
- `request_rate` (`float`):每秒請求數量。
- **返回值**:`AsyncGenerator[Tuple[str, int, int], None]`,異步生成請求。
### 6. `calculate_metrics`
- **功能**:計算基準測試的各種性能指標。
- **參數**:
- `input_requests`:請求的輸入數據列表。
- `outputs`:請求的輸出結果列表。
- `dur_s` (`float`):基準測試的持續時間(秒)。
- `tokenizer`:用於計算標記數的 tokenizer。
- `selected_percentile_metrics` (`List[str]`):選擇要計算百分位數的指標名稱列表。
- `selected_percentiles` (`List[float]`):要計算的百分位數列表。
- `gootput_config_dict` (`Dict[str, float]`):良好吞吐量的配置字典。
- **返回值**:`Tuple[BenchmarkMetrics, List[int]]`,包含計算的指標和實際輸出長度列表。
### 7. `benchmark`
- **功能**:執行基準測試的主要函數。
- **參數**:
- `backend` (`str`):後端服務類型。
- `api_url` (`str`):API 的完整 URL。
- `base_url` (`str`):API 的基礎 URL。
- `model_id` (`str`):模型的名稱。
- `tokenizer`:用於分詞的 tokenizer。
- `input_requests`:請求的輸入數據列表。
- `logprobs` (`Optional[int]`):要計算的 logprob 數量。
- `best_of` (`int`):每個請求生成的序列數。
- `request_rate` (`float`):每秒請求數量。
- `disable_tqdm` (`bool`):是否禁用進度條。
- `profile` (`bool`):是否啟用性能分析。
- `selected_percentile_metrics` (`List[str]`):選擇要計算百分位數的指標名稱列表。
- `selected_percentiles` (`List[float]`):要計算的百分位數列表。
- `ignore_eos` (`bool`):是否忽略 EOS(結束標記)。
- `gootput_config_dict` (`Dict[str, float]`):良好吞吐量的配置字典。
- `max_concurrency` (`Optional[int]`):最大並發請求數。
- **返回值**:`Dict[str, Any]`,包含基準測試結果和相關指標。
### 8. `check_goodput_args`
- **功能**:檢查並解析良好吞吐量(goodput)相關的參數。
- **參數**:
- `args` (`argparse.Namespace`):命令列參數對象。
- **返回值**:`Dict[str, float]`,良好吞吐量的配置字典。
### 9. `parse_goodput`
- **功能**:解析良好吞吐量的服務級別目標(SLO)。
- **參數**:
- `slo_pairs` (`List[str]`):包含 `"指標名稱:值"` 格式的 SLO 列表。
- **返回值**:`Dict[str, float]`,解析後的 SLO 字典。
### 10. `main`
- **功能**:主函數,解析參數並執行基準測試。
- **參數**:
- `args` (`argparse.Namespace`):命令列參數對象。
- **無返回值**:該函數執行基準測試流程,並可選擇性地保存結果。
### 使用說明
- 在使用這些函數前,請確保已經正確配置並載入所需的模型和 tokenizer。
- 可以根據不同的數據集選擇對應的 `sample_*_requests` 函數來生成請求數據。
- 使用 `benchmark` 函數執行基準測試,並根據需要調整參數,如後端類型、請求率、並發數等。
- 測試結束後,可使用 `calculate_metrics` 函數計算性能指標,或直接查看 `benchmark` 函數的返回結果。
<br>
<hr>
<br>
## 函數:`sample_hf_requests`
**功能**:
該函數從 Hugging Face 的數據集中抽樣請求,用於基準測試。它適用於包含對話(`conversations`)的數據集,並支持處理多模態內容(如圖像)。每個請求包含輸入文本(提示)、輸入長度、輸出長度,以及可選的多媒體內容。
**參數**:
- `dataset_path` (`str`):Hugging Face 數據集的 ID,傳遞給 `load_dataset` 函數。
- `dataset_subset` (`str`):Hugging Face 數據集的子集名稱(如果有)。
- `dataset_split` (`str`):Hugging Face 數據集的分割名稱(如 `"train"`、`"test"`)。
- `num_requests` (`int`):要抽樣的請求數量。
- `tokenizer` (`PreTrainedTokenizerBase`):用於分詞的 tokenizer。
- `random_seed` (`int`):隨機種子,用於數據集的洗牌。
- `fixed_output_len` (`Optional[int]`, 預設值:`None`):固定的輸出長度。如果為 `None`,則使用數據集中提供的輸出長度。
**返回值**:
- `List[Tuple[str, int, int, Optional[Dict[str, Collection[str]]]]]`:返回一個列表,每個元素是包含以下值的元組:
- `prompt` (`str`):輸入文本(提示)。
- `prompt_len` (`int`):輸入文本的標記長度。
- `output_len` (`int`):輸出文本的標記長度。
- `mm_content` (`Optional[Dict[str, Collection[str]]]`):可選的多媒體內容(如圖像)。
**執行步驟**:
1. **載入數據集**:
- 使用 `load_dataset` 函數從 Hugging Face 數據集庫載入指定的數據集:
```python
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
```
- `dataset_path`:數據集的名稱或路徑。
- `name`:數據集的子集名稱。
- `split`:數據集的分割。
- `streaming=True`:啟用流式讀取,適合大型數據集。
2. **驗證數據結構**:
- 確保數據集包含 `'conversations'` 欄位,否則拋出錯誤:
```python
assert "conversations" in dataset.features, (
"HF Dataset must have 'conversations' column.")
```
3. **過濾和洗牌數據集**:
- 定義過濾函數,僅保留至少包含兩輪對話的數據:
```python
filter_func = lambda x: len(x["conversations"]) >= 2
```
- 使用隨機種子對數據集進行洗牌,並應用過濾:
```python
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
```
4. **初始化請求列表**:
- 創建一個空列表 `sampled_requests`,用於存儲抽樣的請求。
5. **迭代數據集並抽樣請求**:
- 遍歷 `filtered_dataset`,對每條數據進行處理,直到達到指定的請求數量 `num_requests`:
```python
for data in filtered_dataset:
if len(sampled_requests) == num_requests:
break
# 處理數據...
```
- **提取提示和回應**:
- 提取第一輪對話作為 `prompt`,第二輪對話作為 `completion`:
```python
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]
```
- **計算標記長度**:
- 使用 `tokenizer` 對 `prompt` 和 `completion` 進行分詞,並計算長度:
```python
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (len(completion_token_ids) if fixed_output_len is None
else fixed_output_len)
```
- **過濾過短或過長的序列**:
- 如果未指定 `fixed_output_len`,則過濾掉過短或過長的序列:
```python
if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
continue # 跳過過短的序列
if fixed_output_len is None and (prompt_len > 1024 or prompt_len + output_len > 2048):
continue # 跳過過長的序列
```
- **處理多媒體內容(可選)**:
- 如果數據包含圖像,並且圖像是 `PIL.Image` 對象,則進行處理:
```python
if "image" in data and isinstance(data["image"], Image):
image = data["image"].convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
else:
mm_content = None
```
- **添加到請求列表**:
- 將處理後的數據添加到 `sampled_requests`:
```python
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
```
6. **返回結果**:
- 返回包含抽樣請求的列表:
```python
return sampled_requests
```
**總結**:
- 該函數從 Hugging Face 數據集中抽樣對話數據,用於基準測試。
- 支持處理多模態內容,如圖像,適用於測試多模態模型。
- 提供了選擇固定輸出長度的選項,方便控制生成的文本長度。
- 通過過濾和洗牌,確保抽樣的數據質量和多樣性。
**注意事項**:
- **數據集要求**:
- 數據集必須包含 `'conversations'` 欄位,且每條數據至少有兩輪對話。
- 如果需要處理圖像,數據中的 `'image'` 欄位應該包含 `PIL.Image` 對象。
- **多媒體內容處理**:
- 圖像被轉換為 Base64 編碼的字符串,並嵌入到數據中,適用於需要圖像輸入的模型。
- **過濾條件**:
- 如果未指定 `fixed_output_len`,則自動過濾掉過短或過長的序列,確保模型能夠正常處理。
**示例使用**:
```python
# 假設我們要從 Hugging Face 數據集 "my_dataset" 的 "train" 分割中抽樣 100 個請求
sampled_requests = sample_hf_requests(
dataset_path="my_dataset",
dataset_subset=None, # 如果沒有子集,設置為 None
dataset_split="train",
num_requests=100,
tokenizer=my_tokenizer,
random_seed=42,
fixed_output_len=50 # 可選,固定輸出長度為 50
)
# 每個請求包含:
# - prompt: 輸入文本
# - prompt_len: 輸入的標記長度
# - output_len: 輸出的標記長度
# - mm_content: 可選的多媒體內容
```
**實際應用**:
- 在基準測試中,使用該函數可以生成多樣化的請求,模擬實際應用場景。
- 適用於需要測試多輪對話能力的模型,以及需要處理圖像等多模態輸入的模型。
**可能的修改**:
- 自定義過濾條件:可以修改過濾條件,以適應不同的數據集和應用需求。
- 擴展多模態支持:如果需要處理其他類型的多媒體內容,可以擴展對其他媒體類型的支持。
<br>
<hr>
<br>
## 函數:`sample_random_requests`
> 函數的功能和執行過程。
**功能**:
該函數用於隨機生成一組請求,用於基準測試。每個請求包含一個隨機生成的輸入文本(根據指定的輸入長度和範圍),以及對應的輸入和輸出標記數。
**參數**:
- `prefix_len` (`int`):固定的前綴標記數,作為每個輸入的起始部分。
- `input_len` (`int`):每個請求的目標輸入標記數。
- `output_len` (`int`):每個請求的目標輸出標記數。
- `num_prompts` (`int`):要生成的請求數量。
- `range_ratio` (`float`):輸入/輸出長度的隨機範圍比例,用於在一定範圍內隨機化輸入和輸出長度。
- `tokenizer` (`PreTrainedTokenizerBase`):用於將標記 ID 轉換為文本的 tokenizer。
**返回值**:
- `List[Tuple[str, int, int]]`:返回一個列表,每個元素是包含以下三個值的元組:
- `prompt` (`str`):生成的隨機輸入文本。
- `int(prefix_len + input_lens[i])`:輸入文本的總標記數(包括前綴)。
- `int(output_lens[i])`:對應的輸出標記數。
**執行步驟**:
1. **生成固定前綴**:
- 使用 `np.random.randint(0, tokenizer.vocab_size, size=prefix_len)` 生成一個長度為 `prefix_len` 的隨機標記 ID 序列,範圍在 `[0, tokenizer.vocab_size)`。
- 將這些標記 ID 轉換為列表,存儲在 `prefix_token_ids` 中。
- 這些前綴標記將用於所有生成的輸入文本,保持固定不變。
2. **隨機生成輸入和輸出長度**:
- 使用 `np.random.randint(int(input_len * range_ratio), input_len + 1, size=num_prompts)` 生成一個長度為 `num_prompts` 的輸入長度列表 `input_lens`。每個輸入長度在 `[int(input_len * range_ratio), input_len]` 範圍內隨機選取。
- 使用類似的方法生成輸出長度列表 `output_lens`,範圍在 `[int(output_len * range_ratio), output_len]`。
3. **生成隨機偏移量**:
- 使用 `np.random.randint(0, tokenizer.vocab_size, size=num_prompts)` 生成一個長度為 `num_prompts` 的偏移量列表 `offsets`。這些偏移量將用於確保每個輸入文本的隨機性。
4. **生成輸入文本**:
- 初始化一個空列表 `input_requests`,用於存儲生成的請求。
- 遍歷每一個請求(使用 `for i in range(num_prompts)`):
- 對於第 `i` 個請求:
- 生成隨機標記 ID 序列:
- 使用列表生成式 `[(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]` 生成一個長度為 `input_lens[i]` 的隨機標記 ID 序列。這個序列中的每個標記 ID 都是根據偏移量、索引和詞彙表大小計算得出,確保隨機性並在詞彙表範圍內。
- 將前綴標記 ID 與生成的隨機標記 ID 序列拼接起來,形成完整的輸入標記 ID 序列。
- 使用 `tokenizer.decode` 將輸入標記 ID 序列轉換為文本 `prompt`。
- 計算輸入長度(包括前綴)`int(prefix_len + input_lens[i])`。
- 輸出長度為 `int(output_lens[i])`。
- 將 `(prompt, int(prefix_len + input_lens[i]), int(output_lens[i]), None)` 添加到 `input_requests` 列表中。
5. **返回結果**:
- 函數返回生成的 `input_requests` 列表,包含指定數量的請求,每個請求都有隨機生成的輸入文本、輸入長度和輸出長度。
**總結**:
- 該函數的主要目的是生成一組隨機的請求,用於測試模型在隨機輸入下的性能。
- 通過指定 `prefix_len`,可以為所有請求添加一個固定的前綴,模擬現實中常見的固定上下文。
- 使用 `range_ratio` 可以調整輸入和輸出長度的隨機程度。例如,當 `range_ratio` 設置為 1.0 時,輸入長度在 `[input_len * 1.0, input_len]` 之間隨機選取,即範圍縮小;當設置為較小的值(如 0.8)時,範圍擴大。
- 通過這種方式,可以生成具有多樣性和隨機性的請求,適用於模擬不同長度和內容的輸入,從而對模型進行全面的基準測試。
**注意事項**:
- 為了確保生成的標記 ID 在詞彙表範圍內,使用了取模操作 `% tokenizer.vocab_size`。
- 由於輸入文本是隨機生成的,可能不具有實際意義,但對於性能測試而言是足夠的。