---
# System prepended metadata

title: 【論文筆記】Guiding Text-to-Image Diffusion Model Towards Grounded Generation
tags: [Diffusion, DL, Paper]

---

# 【論文筆記】Guiding Text-to-Image Diffusion Model Towards Grounded Generation

論文連結：
https://arxiv.org/abs/2301.05221

## Overview

這篇論文提出了一個方法，擴展 Stable Diffusion model 來完成 object grounding，也就是在生成 image 的同時，也針對文字 prompt 描述的物體進行分割。他們主要的貢獻包含：
1. 建立了一個生成 dataset 的流程，用以訓練他們提出的模型
2. 提出了一個架構，可以同時生成 image 和把文字裡提到的物體分割出來
3. 經過 evaluate 之後，證實這個架構可以分割在訓練階段沒看過的類別
![](https://hackmd.io/_uploads/By6pyrzr2.png)


下面會先介紹他們使用的模型架構，接著才介紹他們生成 dataset 的方法，最後再說明一些實驗的結果。

## Architecture

首先先簡單講一下整個任務目標的大方向。作者取用了 Stable Diffusion model，目標是構建一個模型，可以輸入 noise 和文字 prompt，輸出 image 和 mask，如下所示：
$$
\{\mathcal{I, m}\} = \Phi_{\text{diffusion}^+}(\epsilon, \mathcal{y})
$$

其中 $\Phi_{\text{diffusion}^+}$ 代表他們利用 Stable Diffusion 進行擴展後的 model。

下圖是整個 grounding module 的架構，可以看到除了 diffusion model 之外，主要包含三個部分，分別是 visual encoder, text encoder 和 fusion module。

![](https://hackmd.io/_uploads/SypeBrzB3.png =500x)

### Text Encoder
Text encoder 採用的是 Stable Diffusion pre-trained text encoder（為 CLIP 的 encoder）。給定 text prompt 作為 input 後，它會生成對應的 embedding：
$$
\mathcal{E}_{\text{obj}_i} = \Phi_\text{t-enc}(g(y_i))
$$

### Visual Encoder

Visual encoder 的輸入是在 denoising timestep $t=5$ 時，Stable Diffusion 裡 UNet 每個 layer 輸出的  intermediate features $\{f^1_i, ..., f^n_i\}$，他們把這些 feature 稱作 visual representation。選用 $t=5$ 的原因是因為經過 ablation study 試驗不同 timestep 的輸出之後，他們發現 $t=5$ 時的輸出可以有最好的效果。

這些 features 輸入 visual encoder 後會輸出一個 fused visual feature：
$$
\mathcal{\hat{F}}_i = \Phi_\text{v-enc} (\{f^1_i, ..., f^n_i\})
$$

![](https://hackmd.io/_uploads/Sygj9SMrn.png =500x)

### Fusion Module

Fusion module 選用的是一個 3 層的 transformer decoder，把 text embedding 轉化成 transformer 的 Query，visual feature 轉化成 Key 和 Value 輸入到 transformer 之後，會輸出一個 segmentation embedding，最後會再經過一個 MLP 並計算和 visual feature 的內積得到 segmentation mask：
\begin{align}
&\mathcal{E}_{\text{seg}_i} = \Phi_\text{transfromer-D}(W^Q \cdot \mathcal{E}_{\text{obj}_i}, W^K \cdot \mathcal{\hat{F}}_i, W^V \cdot \mathcal{\hat{F}}_i) \\
&m_i = \mathcal{\hat{F}}_i \cdot [\Phi_{\text{MLP}}(\mathcal{E}_{\text{seg}_i})]^T
\end{align}


### Training

先假設已經有一個 training set，裡面包含 {visual feature, segmentation, text prompt} pairs，則訓練的 loss 定義為 
$$
\mathcal{L} = -\frac{1}{N} \sum^N_{i=1} [m_i^{gt} \cdot \log(\sigma(m_i)) + (1 - m_i^{gt}) \cdot \log (\sigma(1 - m_i))]
$$

其中 $\sigma(\cdot)$ 為 Sigmoid function。訓練的過程中，text encoder 是凍住的，只訓練 visual encoder 和 fusion module。

## Dataset Collection

從前面 training 的過程，可以發現需要的 training set 當中每一筆資料都是一個 triplet，包含 {visual feature, segmentation, text prompt}，因此這裡的目標就是要建立一個可以生成很多 triplet 的架構。

首先他們準備一些常見類別的詞彙（例如 PASCAL VOC 這個 dataset 裡就有 20 種常見的類別），先把這些詞彙分成兩個子集 $\mathcal{C}_\text{seen}$ 和 $\mathcal{C}_\text{unseen}$，只用 $\mathcal{C}_\text{seen}$ 裡有出現的類別來建立 training set。隨機抽取 $\mathcal{C}_\text{seen}$ 裡面的一到兩個類別建立一個 text prompt（例如 a photograph of a dog and cat），經過 Stable Diffusion model 得到 visual feature 和生成的 image，再把生成的 image 輸入至 pre-trained Mask R-CNN 得到 segmentation mask，如此就獲得了一筆訓練所需的 triplet {visual feature, segmentation, text prompt}。重複以上步驟多次，就可以獲得很多筆這樣的 triplet。

![](https://hackmd.io/_uploads/B1_2lL7r2.png)

## Experiments

### Evaluation on Grounded Generation

第一個實驗測試他們提出的架構在 grounding 上的表現。他們使用 PASCAL VOC 和 MS-COCO 兩個 dataset 裡的類別，用前面提到的方法建立了兩種 training sets，分別訓練之後再做 testing，結果如下表：

![](https://hackmd.io/_uploads/SkMn4PMrh.png)

上表使用的指標是 mIoU (%)，one 或 two 代表的是產生 text prompt 時用的類別數量，Split 1 到 3 則是代表三種不同 $\mathcal{C}_\text{seen}$ 和 $\mathcal{C}_\text{unseen}$ 的分法。

從表中可以看到他們的方法比 DAAM 這個使用 unsupervised leaning 的方法還要來得好。

![](https://hackmd.io/_uploads/BJRKwDzrh.png)

上圖是一些輸出的結果，其中作者特別指出說 sofa, car, hot dog 和 bear 是在訓練過程中沒有看過的類別，但也可以被正確分割出來。

### Open-vocabulary Segmentation

為了進一步驗證他們的 grounding 方法是有效的，這裡又做了另一個實驗，利用他們提出的 grounding module 生成人造的 image-segmentation dataset，接著嘗試用這個 dataset 來訓練一個 semantic segmentation model，並測試表現如何。

首先他們用 PASCAL VOC 裡的 20 種類別，透過他們的 guided Stable Diffusion 生成出一萬組人造的 image-segmentation pairs，接著用這些資料訓練 MaskFormer，得到 semantic segmentation model。

![](https://hackmd.io/_uploads/B1cRpPMr2.png =500x)

上半部分的 zero-shot segmentation methods 是訓練在 PASCAL-VOC training set 上。從表中可以看到，訓練在人造 dataset 上的 MaskFormer 比大部分的 zero-shot segmentation 方法表現來得好，和目前的 state-of-the-art ZegFormer 則是表現還有一些差異，但 ZegFormer 訓練在 real image 上，而他們訓練 MaskFormer 是用生成出來的人造 images。

