# 基於大型醫療視覺語言模型的多模態RAG檢索系統
> 原文名稱:[MMED-RAG: VERSATILE MULTIMODAL RAG SYSTEM FOR MEDICAL VISION LANGUAGE MODELS](https://iclr.cc/virtual/2025/poster/28145)
## 問題
1. 不同醫療影像間分佈差距大,通用檢索可能引入不相關或錯誤資訊
2. 相似度分數會跟隨圖片與文字的對齊程度變動,固定的Top-K值無法適應不同複雜度的檢索,導致檢索結果包含無用資訊影響生成
3. 模型直接使用RAG造成圖文對齊產生偏差,文字檢索到上下文生成回答,而不考慮圖片資訊、過度依賴檢索資訊,直接抄作業、檢索到的不相關資訊的干擾、
## 先導實驗
1. top-K與相似性分數通常會在某個數量後急劇下降,表明固定K值會讓較低質量的上下文索引成爲輸入餵給VLM
2. 對齊實驗:
1. 在引入 RAG 後,根據加入噪聲的圖像檢索的上下文,55.08%的這些案例返回正確答案。這表明模型直接參考檢索的知識,而不考慮輸入圖像
2. 43.31%原本正確回答的問題在引入 RAG 後被錯誤回答,這表明來自不正確檢索信息的干擾
## 解法(三個模塊)

1. 領域感知檢索機制 (Domain-Aware Retrieval Mechanism):使用一個小資料集微調BiomedCLIP判斷該醫療影像屬於哪個種類(Ex:胸部X光、腦部CT......),並使用該類型的檢索器對RAG進行檢索
2. 適應檢索上下文選擇 (Adaptive Retrieved Context Selection):在 RAG 階段使用相似度分數來過濾低品質的資訊,以Gap Statistic算法為靈感,計算相似度比率是否在域值內,否則該上下文不是相關訊息
3. 基於 RAG 的偏好微調 (RAG-based Preference Fine-tuning):用抄作業比喻:我們想要模型:1.自主思考,自己知道答案能不直接複製就不複製 2.學會如何複製,不知道答案看別人的也不要抄錯 3.知道別人是錯的 如果別人的作業是錯的不如不要看自己猜
* DPO(Direct Preference Optimization):
一個完整的LLM訓練流程會經歷以下三個階段([ref](https://datasciocean.tech/deep-learning-core-concept/llm-fine-tuning-rlhf/)):
1. Pre-Training:透過大量的文字接龍訓練模型獲得理解語義的能力
2. Supervised Fine-Tuning:透過[回應,標籤]讓模型遵循指令
3. RLHF(Reinforcement Learning from Human Feedback):透過強化學習中的Reward model針對模型的回應打分,模型根據這些評分來更新自己,讓自己能夠得到最好的分數


但RLHF有兩個問題([ref](https://datasciocean.tech/paper-intro/direct-preference-optimization/)):
1. 需事先收集好的 Preference Dataset 訓練一個 Reward Model,要Train起來的關鍵就是Dataset品質要足夠好
2. 過 RL 演算法讓 LLM 透過最大化 Reward 並且受到KL散度的限制的過程來學習正確的輸出
而DPO就是為了不去使用強化學習,而是轉為使用監督式學習的方式訓練,免去建立Reward Model的麻煩
#### DPO loss function

從 DPO Loss Function 可以發現,給定一個 Preference Dataset 並抽樣出一個 Sample 包含 Prompt (x)、Winning Response (y_w) 與 Losing Response (y_l);基於 Prompt (x),模型 (PI_Theta) 必須學習讓 Winning Response (y_w) 出現的機率愈高,並讓 Losing Response (y_l) 出現的機率愈小。這個過程其實就相當於在原來的 RLHF 中,模型要學習輸出 Winning Response 來得到比較大的 Reward(避免輸出 Losing Response 因為得到的 Reward 也會比較小)。
此外,你還可以看到減號的前項與後項都分別除以參考模型 (PI_Ref) 生成相對應的 Response 的機率,這呼應了原來的 RLHF 中的 KL-Divergence 限制(不希望模型(PI_Theta)在學習的過程中,和原來的自己(PI_Ref)差異太大)。
* 偏好對:一個偏好對通常包含一個輸入 (x) 以及針對該輸入的兩個不同的回應:一個被偏好的回應 (yw 或 yw,o) 和一個被不偏好的回應 (yl 或 yl,o)
* 跨模態對齊 (Cross-Modality Alignment):關注模型如何有效地整合和利用來自不同模態的資訊,構建偏好對時,可能會比較模型在提供原始圖像時的正確回應,與提供一個與原始圖像相似度最低且加入了大量雜訊的圖像時的回應,透過訓練,模型被鼓勵即使在檢索到上下文時,仍需依賴和利用輸入的醫學影像來生成回應
雜訊圖片:選擇與目標圖像相似度最低的圖像x引入擴散噪聲
* 整體對齊 (Overall Alignment):關注模型生成的回應與真實情況 (ground truth) 之間的對齊,重於模型如何正確地理解和利用檢索到的外部知識,
* 當模型在沒有 RAG 時是錯誤的,但在有 RAG 時變為正確時,將有 RAG 的正確回應列為偏好的 [50, Algorithm 1 step 10-13]。這鼓勵模型學習如何利用檢索到的正確資訊,幫助模型更可靠地利用檢索到的知識,減少 CR 和 OR 錯誤
* 當模型在沒有 RAG 時是正確的,但在有 RAG 時變為錯誤時,將沒有 RAG 的正確回應列為偏好的,將有 RAG 的錯誤回應列為不偏好的 [16, 50, Algorithm 1 step 15-17]。這教導模型不要受到不相關或錯誤檢索資訊的干擾
* Doa1:加強模型對檢索知識的理解和推理能力,偏好根據原始圖像和檢索信息正確回答,不偏好未使用檢索的情況下根據圖像錯誤回答的情況
* Doa2:減少來自檢索知識的干擾,偏好根據原始圖像正確回答的情況,不偏好同時使用圖像和檢索到的信息獲得錯誤回答
* 演算法虛擬碼

1. ▷ Training Stage: 標示以下是訓練階段的步驟。
2. `Initialize Dcm with an empty set`: 初始化 Dcm 為一個空集。Dcm 將用於存放構建的「跨模態對齊」偏好對。
3. `foreach (xv, xt, y) ∈ D do`: 對於資料集 D 中的每一個樣本 (xv, xt, y) 進行迭代。
4. `Generate retrieved contexts with an assigned domain label xr ←RF(xv)(xv)`: 使用領域識別函式 F(xv) 確定的檢索器 RF(xv),基於圖像 xv 檢索相關的外部上下文 xr。這裡檢索結果被認為與特定領域相關。
5. Generate the noisy image x∗ v ← I(xv): 使用雜訊函式 I(·) 為當前樣本的原始圖像 xv 生成一個雜訊圖像 x*v。
6. ▷ Cross-Modality Alignment: 標示以下是構建用於改善「跨模態對齊」的偏好對的步驟。
7. if M(xv, (xt, xr)) = y and M(x∗ v, (xt, xr)) = y then: 條件判斷:檢查模型在使用原始圖像 (xv)、查詢 (xt) 和檢索到的上下文 (xr) 時能否生成正確答案 y (M(xv, (xt, xr)) = y),並且 在使用雜訊圖像 (x*v)、查詢 (xt) 和檢索到的上下文 (xr) 時也能生成正確答案 y (M(x*v, (xt, xr)) = y)。
>解釋:這種情況表示模型在圖像資訊被干擾時,仍然能夠僅依靠文字資訊(查詢 + RAG)得出正確答案。這正是「跨模態未對齊」的表現——模型可能忽略了圖像信息。構建偏好對的目的是教導模型更依賴真實圖像。
8. Select the preferred response yw,o1 ← y, dispreferred response yl,o1 ←M(x∗ v, (xt, xr)): 如果上述條件為真,則構建一個偏好對。將真實答案 y 設為偏好回應 (yw,o1)。將模型使用雜訊圖像得到的結果 (M(x*v, (xt, xr))) 設為不偏好回應 (yl,o1)。根據 Line 7 的條件,這裡的 yl,o1 實際上也是 y。
>解釋:即使兩個回應都是正確的 y,將它們分別標記為偏好和不偏好,並配合 DPO 損失,可以鼓勵模型在看到原始圖像時,生成 y 的概率(或對數概率)比看到雜訊圖像時更高。
9. Put {(xv, x∗ v, xt), yw,o1, yl,o1} into Dcm: 將構建的偏好對樣本 {(原始圖像 xv, 雜訊圖像 x*v, 文字查詢 xt), 偏好回應 yw,o1, 不偏好回應 yl,o1} 放入 Dcm 數據集。這個元組 (xv, x*v, xt) 代表用於訓練時的輸入 x。
10. ▷ Overall Alignment: 標示以下是構建用於改善「整體對齊」(特別是 RAG 利用的可靠性)的偏好對的步驟。
11. Initialize D1 oa and D2 oa with empty set: 初始化 D1oa 和 D2oa 為空集。這兩個集合將用於構建兩類「整體對齊」的偏好對。
12. if M(xv, (xt, xr)) = y and M(xv, xt) ̸= y then: 條件判斷 (第一類整體對齊 D1oa):檢查模型在使用原始圖像、查詢和檢索到的上下文時能生成正確答案 y (M(xv, (xt, xr)) = y),並且 在僅使用原始圖像和查詢(沒有 RAG)錯誤答案 (M(xv, xt) != y)。
>解釋:這種情況表明 RAG 引入的外部知識有助於模型得出正確答案。這類偏好對旨在鼓勵模型在檢索信息有益時,學會正確利用它 [對話歷史]。
13. Select the preferred response yw,o2 ← y, dispreferred response yl,o2 ←M(xv, xt): 如果上述條件為真,將真實答案 y(從有 RAG 的情況)設為偏好回應 (yw,o2),將模型在沒有 RAG 時生成的錯誤回應 (M(xv, xt)) 設為不偏好回應 (yl,o2)。
14. Put {(xv, xt), yw,o2, yl,o2} into D1 oa: 將構建的偏好對樣本 {(原始圖像 xv, 文字查詢 xt), 偏好回應 yw,o2, 不偏好回應 yl,o2} 放入 D1oa 數據集。
15. if M(xv, xt) = y and M(xv, (xt, xr)) ̸= y then: 條件判斷 (第二類整體對齊 D2oa):檢查模型在僅使用原始圖像和查詢(沒有 RAG)原始圖像、查詢和檢索到的上下文時生成了錯誤答案 (M(xv, (xt, xr)) != y)。
> 解釋:這種情況表明 RAG 引入的外部知識造成了干擾,導致模型從正確變為錯誤。這類偏好對旨在鼓勵模型在檢索信息有害時,學會不過度依賴它,轉而依賴其內部知識
16. Select the preferred response yw,o3 ← y, dispreferred response yl,o3 ←M(xv, (xt, xr)): 如果上述條件為真,將真實答案 y(從沒有 RAG 的情況)設為偏好回應 (yw,o3),將模型在有 RAG 時生成的錯誤回應 (M(xv, (xt, xr))) 設為不偏好回應 (yl,o3)。
17. Put {(xv, xt), yw,o3, yl,o3} into D2 oa: 將構建的偏好對樣本 {(原始圖像 xv, 文字查詢 xt), 偏好回應 yw,o3, 不偏好回應 yl,o3} 放入 D2oa 數據集。
18. Dpt = Dcm ∪ Doa, Doa = D1 oa ∪ D2 oa: 將構建好的所有偏好對 (Dcm、D1oa、D2oa) 合併成最終用於偏好微調的數據集 Dpt。Doa 是 D1oa 和 D2oa 的聯集。
19. foreach ((xv, x ∗ v, xt), yw,o, yl,o) ∈ Dpt do: 對於 Dpt 中的每一個偏好對樣本進行迭代。這裡的循環變數結構 (xv, x*v, xt) 似乎與 Line 14/17 存儲 Doa 樣本時的結構 (xv, xt) 不完全一致,這強化了前面提到的偽代碼結構上的潛在歧義。最可能的情況是,Dpt 包含兩種類型的樣本,循環邏輯需要區分處理。
20. Compute the losses Lpt following equation 4 and update πref: 根據 DPO 損失函數 (Equation 4 實際上是 Equation 10 和 Equation 5 的組合,都基於 Equation 18) 計算偏好損失 Lpt。然後使用這個損失來更新模型參數 πref。
> 注意:再次強調,根據標準 DPO 和論文的描述,被更新的應該是模型本身的參數 (πθ),而不是固定的參考模型 (πref 或 πo)。這裡寫更新 πref 很可能是筆誤,實際應更新 πθ。這個更新過程使用了 DPO 損失,目標是使模型在給定輸入 x 時,生成 yw,o 的概率高於生成 yl,o 的概率。這個訓練過程使用了 LoRA 等技術。
21. ▷ Inference Stage: 標示以下是推論(使用訓練好的模型進行預測)階段的步驟。
22. foreach test sample (xv, xt) do: 對於每一個測試樣本 (xv, xt) 進行迭代。
23. Select top-k retrieved contexts with an assigned domain label xr ←RF(xv)(xv): 與訓練階段類似,使用領域識別函式 F(xv) 確定的檢索器 RF(xv),為當前測試樣本檢索相關的外部上下文 xr。這裡通常會根據相似度分數選擇最佳數量的上下文 (top-k 或 Adaptive-k)。
24. Get the predictions of the model w/ RAG-PT p←M(xv, (xt, xr)): 將原始圖像 xv、文字查詢 xt 和檢索到的上下文 xr 作為輸入提供給經過 RAG-PT 微調後的 MMed-LVLM,得到模型的最終預測結果 p。這個結果應具有更高的事實準確性和可靠性。
* Loss Function:與DPO Loss 一致
* 證明跨模態對齊的改進
* 證明整體對齊的改進

## 實驗
### 參數設置
* 檢索器
* 視覺編碼器 (vision encoder):使用 ResNet-50
* 文字編碼器 (text encoder):使用 bio-BioClinicalBERT
* 優化器 (Optimizer): AdamW
* 學習率 (Learning Rate): $10^{-3}$
* 權重衰減 (Weight Decay): 10⁻²
* 批次大小 (Batch Size): 512
* 訓練週期 (Epochs): 360
* Epoch:=二階段,第一3Epoch,第二1Epoch
* 訓練時間:A6000*4 3小時 A100 20小時
* 偏好微調
* 骨幹模型 (Backbone Model): LLaVA-Med-1.5 7B
* 批次大小 (Batch Size): 32
* 偏好微調方法 (RAG-PT): LoRA
* 訓練時間A6000*4 4小時
### 實驗結果
#### 1. 評估VQA、報告生成能力
VQA:比原始的 Med-LVLM 提高了 18.5%,完全超越decoding,相比其他RAG方法平均提升2.8%,大部分占優

報告生成: 比原始的 Med-LVLM 提高了 69.1%,,完全超越decoding,相比其他RAG方法提升16.1%,大部分占優


#### 2.比較其他預訓練大語言模型
顯著超越了在大規模數據集上預訓練的醫療大型語言模型

#### 3.消融實驗
評估加上各模組後真實性的變化

#### 4.消融實驗-偏好數據集對 RAG-PT 的影響
評估加上跨模態(Dcm)、整體對齊(Doa1、Doa2)資料集後的真實性變化

#### 5.MMed-RAG 在減輕不對齊問題方面的有效性
複製參考(CR)率:模型可能會直接複製參考信息,CR 率降至 28.19%
過度依賴(OR)率:受 RAG 干擾影響的錯誤比例,,最初為 43.31%,在納入 MMed-RAG 後降至 8.38%

在上下文方面,將attention map可視化後可以發現在檢索到錯誤訊息時,圖片的參考變多,對RAG的參考變少

# 總結
1. 使用領域感知的檢索機制、自適應校準獲得最佳的檢索上下文品質
2. 偏好微調讓模型更容易關注圖片
3. 對vector RAG方法前中後階段皆做了改進
