# 【論文筆記】Guiding Text-to-Image Diffusion Model Towards Grounded Generation 論文連結: https://arxiv.org/abs/2301.05221 ## Overview 這篇論文提出了一個方法,擴展 Stable Diffusion model 來完成 object grounding,也就是在生成 image 的同時,也針對文字 prompt 描述的物體進行分割。他們主要的貢獻包含: 1. 建立了一個生成 dataset 的流程,用以訓練他們提出的模型 2. 提出了一個架構,可以同時生成 image 和把文字裡提到的物體分割出來 3. 經過 evaluate 之後,證實這個架構可以分割在訓練階段沒看過的類別 ![](https://hackmd.io/_uploads/By6pyrzr2.png) 下面會先介紹他們使用的模型架構,接著才介紹他們生成 dataset 的方法,最後再說明一些實驗的結果。 ## Architecture 首先先簡單講一下整個任務目標的大方向。作者取用了 Stable Diffusion model,目標是構建一個模型,可以輸入 noise 和文字 prompt,輸出 image 和 mask,如下所示: $$ \{\mathcal{I, m}\} = \Phi_{\text{diffusion}^+}(\epsilon, \mathcal{y}) $$ 其中 $\Phi_{\text{diffusion}^+}$ 代表他們利用 Stable Diffusion 進行擴展後的 model。 下圖是整個 grounding module 的架構,可以看到除了 diffusion model 之外,主要包含三個部分,分別是 visual encoder, text encoder 和 fusion module。 ![](https://hackmd.io/_uploads/SypeBrzB3.png =500x) ### Text Encoder Text encoder 採用的是 Stable Diffusion pre-trained text encoder(為 CLIP 的 encoder)。給定 text prompt 作為 input 後,它會生成對應的 embedding: $$ \mathcal{E}_{\text{obj}_i} = \Phi_\text{t-enc}(g(y_i)) $$ ### Visual Encoder Visual encoder 的輸入是在 denoising timestep $t=5$ 時,Stable Diffusion 裡 UNet 每個 layer 輸出的 intermediate features $\{f^1_i, ..., f^n_i\}$,他們把這些 feature 稱作 visual representation。選用 $t=5$ 的原因是因為經過 ablation study 試驗不同 timestep 的輸出之後,他們發現 $t=5$ 時的輸出可以有最好的效果。 這些 features 輸入 visual encoder 後會輸出一個 fused visual feature: $$ \mathcal{\hat{F}}_i = \Phi_\text{v-enc} (\{f^1_i, ..., f^n_i\}) $$ ![](https://hackmd.io/_uploads/Sygj9SMrn.png =500x) ### Fusion Module Fusion module 選用的是一個 3 層的 transformer decoder,把 text embedding 轉化成 transformer 的 Query,visual feature 轉化成 Key 和 Value 輸入到 transformer 之後,會輸出一個 segmentation embedding,最後會再經過一個 MLP 並計算和 visual feature 的內積得到 segmentation mask: \begin{align} &\mathcal{E}_{\text{seg}_i} = \Phi_\text{transfromer-D}(W^Q \cdot \mathcal{E}_{\text{obj}_i}, W^K \cdot \mathcal{\hat{F}}_i, W^V \cdot \mathcal{\hat{F}}_i) \\ &m_i = \mathcal{\hat{F}}_i \cdot [\Phi_{\text{MLP}}(\mathcal{E}_{\text{seg}_i})]^T \end{align} ### Training 先假設已經有一個 training set,裡面包含 {visual feature, segmentation, text prompt} pairs,則訓練的 loss 定義為 $$ \mathcal{L} = -\frac{1}{N} \sum^N_{i=1} [m_i^{gt} \cdot \log(\sigma(m_i)) + (1 - m_i^{gt}) \cdot \log (\sigma(1 - m_i))] $$ 其中 $\sigma(\cdot)$ 為 Sigmoid function。訓練的過程中,text encoder 是凍住的,只訓練 visual encoder 和 fusion module。 ## Dataset Collection 從前面 training 的過程,可以發現需要的 training set 當中每一筆資料都是一個 triplet,包含 {visual feature, segmentation, text prompt},因此這裡的目標就是要建立一個可以生成很多 triplet 的架構。 首先他們準備一些常見類別的詞彙(例如 PASCAL VOC 這個 dataset 裡就有 20 種常見的類別),先把這些詞彙分成兩個子集 $\mathcal{C}_\text{seen}$ 和 $\mathcal{C}_\text{unseen}$,只用 $\mathcal{C}_\text{seen}$ 裡有出現的類別來建立 training set。隨機抽取 $\mathcal{C}_\text{seen}$ 裡面的一到兩個類別建立一個 text prompt(例如 a photograph of a dog and cat),經過 Stable Diffusion model 得到 visual feature 和生成的 image,再把生成的 image 輸入至 pre-trained Mask R-CNN 得到 segmentation mask,如此就獲得了一筆訓練所需的 triplet {visual feature, segmentation, text prompt}。重複以上步驟多次,就可以獲得很多筆這樣的 triplet。 ![](https://hackmd.io/_uploads/B1_2lL7r2.png) ## Experiments ### Evaluation on Grounded Generation 第一個實驗測試他們提出的架構在 grounding 上的表現。他們使用 PASCAL VOC 和 MS-COCO 兩個 dataset 裡的類別,用前面提到的方法建立了兩種 training sets,分別訓練之後再做 testing,結果如下表: ![](https://hackmd.io/_uploads/SkMn4PMrh.png) 上表使用的指標是 mIoU (%),one 或 two 代表的是產生 text prompt 時用的類別數量,Split 1 到 3 則是代表三種不同 $\mathcal{C}_\text{seen}$ 和 $\mathcal{C}_\text{unseen}$ 的分法。 從表中可以看到他們的方法比 DAAM 這個使用 unsupervised leaning 的方法還要來得好。 ![](https://hackmd.io/_uploads/BJRKwDzrh.png) 上圖是一些輸出的結果,其中作者特別指出說 sofa, car, hot dog 和 bear 是在訓練過程中沒有看過的類別,但也可以被正確分割出來。 ### Open-vocabulary Segmentation 為了進一步驗證他們的 grounding 方法是有效的,這裡又做了另一個實驗,利用他們提出的 grounding module 生成人造的 image-segmentation dataset,接著嘗試用這個 dataset 來訓練一個 semantic segmentation model,並測試表現如何。 首先他們用 PASCAL VOC 裡的 20 種類別,透過他們的 guided Stable Diffusion 生成出一萬組人造的 image-segmentation pairs,接著用這些資料訓練 MaskFormer,得到 semantic segmentation model。 ![](https://hackmd.io/_uploads/B1cRpPMr2.png =500x) 上半部分的 zero-shot segmentation methods 是訓練在 PASCAL-VOC training set 上。從表中可以看到,訓練在人造 dataset 上的 MaskFormer 比大部分的 zero-shot segmentation 方法表現來得好,和目前的 state-of-the-art ZegFormer 則是表現還有一些差異,但 ZegFormer 訓練在 real image 上,而他們訓練 MaskFormer 是用生成出來的人造 images。