
Matching Networks for One Shot Learning

Code

Abstract

  • In this work, we employ ideas from metric learning based on deep neural features and from recent advances that augment neural networks with external memories.
  • Our framework learns a network that maps a small labelled support set and an unlabelled example to its label, obviating the need for fine-tuning to adapt to new class types.

1 Introduction

  • The novelty of our work is two-fold: at the modeling level, and at the training procedure.
    • modeling level: We propose Matching Nets (MN), a neural network which uses recent advances in attention and memory that enable rapid learning.
    • training procedure: our training procedure is based on a simple machine learning principle: test and train conditions must match. Thus to train our network to do rapid learning, we train it by showing only a few examples per class, switching the task from minibatch to minibatch, much like how it will be tested when presented with a few examples of a new task.
  1. Some non-parametric models (e.g. KNN) can pick up new samples quickly. This paper combines a parametric model (deep learning) with a non-parametric one: a DL model discards its training samples once they are used, whereas KNN keeps its samples around.
  2. novelty: proposes the matching network & a new training procedure

2 Model

Our non-parametric approach to solving one-shot learning is based on the following two points:

  1. our model architecture follows recent advances in neural networks augmented with memory. Given a (small) support set \(S\), our model defines a function \(c_S\) (or classifier) for each \(S\), i.e. a mapping \(S \rightarrow c_S(\cdot)\).
    • The emphasis is that every support set can be mapped to a different classifier, i.e. the mapping \(S\rightarrow c_S(\cdot)\).
  2. we employ a training strategy which is tailored for one-shot learning from the support set \(S\).

2.1 Model Architecture

  • Our contribution is to cast the problem of one-shot learning within the set-to-set framework [26].
    • [26] Order matters: Sequence to sequence for sets. 2015
  • The key point is that when trained, Matching Networks are able to produce sensible test labels for unobserved classes without any changes to the network. (That is, at test time the classifier for a new set of classes is defined entirely by the support set fed to the network, so no fine-tuning or weight updates are needed for unseen classes.)
  • More precisely, we wish to map from a (small) support set \(S\) to a classifier \(c_S\). We define the mapping \(S\rightarrow c_S(\hat x)\) to be \(P(\hat y| \hat x, S)\).
  • Our model in its simplest form computes \(\hat y\) as follows:
    • \(\hat y = \sum_{i=1}^{k} a(\hat x, x_i)\, y_i \tag{1}\)
    • eq. 1 essentially describes the output for a new class as a linear combination of the labels in the support set.
    • Where the attention mechanism \(a\) is a kernel on \(X \times X\), then (1) is akin to a kernel density estimator (KDE). Where the attention mechanism is zero for the \(b\) furthest \(x_i\) from \(\hat x\) according to some distance metric and an appropriate constant otherwise, then (1) is equivalent to '\(k - b\)'-nearest neighbours (although this requires an extension to the attention mechanism, described in Section 2.1.2). Thus (1) subsumes both KDE and kNN methods. (In other words, depending on how \(a\) is chosen, eq. 1 behaves either like a kernel-weighted average of the support labels or like a nearest-neighbour vote.)
    • Another view of (eq. 1) is where \(a\) acts as an attention mechanism and the \(y_i\) act as memories bound to the corresponding \(x_i\). In this case we can understand this as a particular kind of associative memory where, given an input, we "point" to the corresponding example in the support set, retrieving its label.
  • However, unlike other attentional memory mechanisms [2], (1) is non-parametric in nature: as the support set size grows, so does the memory used. Hence the functional form defined by the classifier \(c_S(\hat x)\) is very flexible and can adapt easily to any new support set.

Inspired by:

  • attention
  • memory networks
    • On the connection to memory networks: my personal understanding is that the images in the support set play the role of the sentences in a document, from which an LSTM extracts information, while a query-set image plays the role of the question vector and is used to attend over the support-set images. There may also be multiple hops ("hopping"), but I am not sure.
  • pointer networks

prediction \(\hat y = f(D^{train}, x^{test})\)
Expressed as a probability: \(P(\hat y|\hat x, S)\), where \(S = \{(x_i, y_i)\}_{i=1}^k\)

Matching Net expresses the model as: \[\hat y = \sum_{i=1}^{k} a(\hat x, x_i)\, y_i\]

  • The prediction \(\hat y\) is viewed as a linear combination of the labels of the support-set samples.
    • If \(a(\hat x, x_i)\) acts as a kernel function, the model is roughly a DL embedding layer followed by KDE (kernel density estimation) as the classifier.
    • If \(a(\hat x, x_i)\) acts as a 0-1 function, the model is roughly a DL embedding layer followed by KNN as the classifier (a small sketch of both cases follows below).
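
A minimal numpy sketch of eq. (1) under these two readings, assuming one-hot labels; `predict`, `kernel_attention`, and `knn_attention` are my own illustrative names, not the paper's:

```python
import numpy as np

def predict(x_hat, support_x, support_y_onehot, attention):
    """Eq. (1): y_hat = sum_i a(x_hat, x_i) * y_i, with one-hot y_i."""
    weights = attention(x_hat, support_x)           # shape (k,)
    return weights @ support_y_onehot               # shape (num_classes,)

def kernel_attention(x_hat, support_x, bandwidth=1.0):
    """KDE-like choice: a Gaussian kernel, normalized over the support set."""
    dists = np.linalg.norm(support_x - x_hat, axis=1)
    w = np.exp(-0.5 * (dists / bandwidth) ** 2)
    return w / w.sum()

def knn_attention(x_hat, support_x, b=2):
    """kNN-like choice: zero weight on the b furthest x_i, a constant otherwise,
    so eq. (1) reduces to (k - b)-nearest-neighbour voting."""
    dists = np.linalg.norm(support_x - x_hat, axis=1)
    keep = dists.argsort()[: len(support_x) - b]
    w = np.zeros(len(support_x))
    w[keep] = 1.0 / len(keep)
    return w
```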

2.1.1 Attention Kernel

  • the simplest form is to use the softmax over the cosine distance \(c\), i.e., \(a(\hat x, x_i) = e^{c(f(\hat x), g(x_i))} / \sum_{j=1}^k e^{c(f(\hat x), g(x_j))}\)
    • \(f\) embed \(\hat x\); and \(g\) embed \(x_i\)
  • the objective that we are trying to optimize is precisely aligned with multi-way, one-shot classification, and thus we expect it to perform better than its counterparts. (That is, the training loss is the N-way classification error of the episode itself rather than a surrogate pairwise metric-learning loss, so the objective optimized during training matches the task faced at test time.)

This paper gives \(a(\hat x, x_i)\) a new form: it is viewed as an attention kernel, so the model's prediction is dominated by the label of the support-set image that receives the most attention. A common attention kernel is cosine similarity followed by a softmax:\[a(\hat x, x_i) = \dfrac{e^{c(f(\hat x), g(x_i))}}{\sum_{j=1}^k e^{c(f(\hat x), g(x_j))}}\]

  • \(f, g\) are two embedding functions, which can be neural networks (a small sketch of this kernel follows).
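
A numpy sketch of this attention kernel, assuming \(f(\hat x)\) and \(g(x_i)\) have already been computed as embedding vectors; the helper names are mine, not the paper's:

```python
import numpy as np

def cosine_similarity(u, V):
    """Cosine similarity between a vector u and each row of the matrix V."""
    u = u / (np.linalg.norm(u) + 1e-8)
    V = V / (np.linalg.norm(V, axis=1, keepdims=True) + 1e-8)
    return V @ u                                    # shape (k,)

def attention_kernel(f_xhat, g_support):
    """a(x_hat, x_i): softmax over the cosine similarities c(f(x_hat), g(x_i))."""
    c = cosine_similarity(f_xhat, g_support)
    e = np.exp(c - c.max())                         # numerically stable softmax
    return e / e.sum()

# usage: y_hat = attention_kernel(f_xhat, g_support) @ support_y_onehot
```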

2.1.2 Full Context Embeddings

  • Although the classification strategy is fully conditioned on the whole support set through \(P(\cdot |\hat x, S)\), each \(x_i\) so far is embedded by \(g(x_i)\) independently of the other elements of the support set. Furthermore, \(S\) should be able to modify how we embed the test image \(\hat x\) through \(f\). So:
    • \(g\) becomes \(g(x_i, S)\)
      • to embed data in support set
    • \(f(\hat x, S) = attLSTM(f'(\hat x), g(S), K)\)
      • to embed data in query set
      • \(f'(\hat x)\) are the features (e.g., derived from a CNN) which are input to the LSTM
      • \(g(S)\) is the set over which we attend, embedded with g

The embedding vector \(g(x_i) \leftarrow g(x_i, S)\): the output of the embedding function now depends on both the corresponding \(x_i\) and the whole support set. Since the support set is randomly re-sampled each time, letting the embedding function consider both \(x_i\) and the support set helps smooth out the variation caused by this random selection. This is similar to the relationship between a word and its context in machine translation: \(S\) can be viewed as the context of \(x_i\), which is why the paper uses an LSTM in the embedding function.

Each \(x_i\) in the support set is passed through several convolution layers, and the resulting features then go through a bi-LSTM layer.
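
A rough PyTorch sketch of this fully conditional embedding \(g\), following my reading of the paper's appendix, where \(g(x_i, S) = \vec h_i + \overleftarrow h_i + g'(x_i)\) with a bi-LSTM run over the CNN features of the whole support set; the class and argument names are assumptions:

```python
import torch
import torch.nn as nn

class FullyConditionalG(nn.Module):
    """g(x_i, S): a bi-LSTM over the CNN features g'(x_i) of the whole support set,
    added back to g'(x_i) as a residual (a sketch, not the authors' exact code)."""
    def __init__(self, feat_dim):
        super().__init__()
        # hidden size = feat_dim per direction so the shapes line up for the residual
        self.bilstm = nn.LSTM(feat_dim, feat_dim, bidirectional=True, batch_first=True)

    def forward(self, support_feats):                     # (k, feat_dim) = g'(x_i) for all i
        out, _ = self.bilstm(support_feats.unsqueeze(0))  # (1, k, 2 * feat_dim)
        fwd, bwd = out[0].chunk(2, dim=-1)                # forward / backward halves
        return support_feats + fwd + bwd                  # g(x_i, S), shape (k, feat_dim)
```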

The Fully Conditional Embedding \(f\)

\[
\begin{aligned}
\hat h_k, c_k &= \mathrm{LSTM}(f'(\hat x), [h_{k-1}, r_{k-1}], c_{k-1})\\
h_k &= \hat h_k + f'(\hat x)\\
r_{k-1} &= \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i))\, g(x_i)\\
a(h_{k-1}, g(x_i)) &= \mathrm{softmax}(h_{k-1}^T g(x_i))
\end{aligned}
\]

  • \(f'(\hat x)\) are the CNN-extracted features of the query image, and \(g(x_i)\) are the embeddings of the support-set examples (the paper writes \(g'(x_i)\) for the raw CNN features of a support example); a rough sketch follows.
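
A rough PyTorch sketch of these attLSTM read steps. The paper conditions the LSTM on the concatenation \([h_{k-1}, r_{k-1}]\); to keep the shapes fixed, this sketch sums \(h\) and \(r\) instead, a simplification seen in public re-implementations, and all names are my own:

```python
import torch
import torch.nn as nn

class FullyConditionalF(nn.Module):
    """f(x_hat, S) = attLSTM(f'(x_hat), g(S), K): K attention read steps over the
    embedded support set (a sketch; h and r are summed rather than concatenated)."""
    def __init__(self, feat_dim, K):
        super().__init__()
        self.K = K
        self.lstm_cell = nn.LSTMCell(input_size=feat_dim, hidden_size=feat_dim)

    def forward(self, f_prime_xhat, g_support):           # (q, d) queries, (k, d) support
        h_hat = torch.zeros_like(f_prime_xhat)
        c = torch.zeros_like(f_prime_xhat)
        for _ in range(self.K):
            h = h_hat + f_prime_xhat                       # h_k = h_hat_k + f'(x_hat)
            a = torch.softmax(h @ g_support.t(), dim=1)    # a(h_{k-1}, g(x_i))
            r = a @ g_support                              # r_{k-1} = sum_i a_i * g(x_i)
            h_hat, c = self.lstm_cell(f_prime_xhat, (h + r, c))
        return h_hat + f_prime_xhat                        # f(x_hat, S)
```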

So, revisiting the structure once more.

The following is my personal understanding (of the architecture figure):

  • \(g_\theta\): \(g(x_i, S)\), the embedding function is conditioned on the whole support set and then takes \(x_i\) as input
  • the four flat boxes to the right of \(g_\theta\): the embeddings
  • the circle-and-cross symbols: similarity computed with cosine similarity
  • the four dots on the right: the attention weights
  • \(f_\theta\): \(f(\hat x, S)\), the query embedding, also conditioned on the support set
  • the flat box above \(f_\theta\): the embedding
  • the four squares on the right: the labels
  • the final square: the answer, obtained as a linear combination of the support-set labels weighted by the attention weights, with the final answer given by an argmax

2.2 Training Strategy

  • one batch contains multiple tasks
  • one task consists of a support set and one test example
  • a support set contains multiple samples
  • exactly one sample in the support set has the same class as the test sample

To do N-way K-shot classification:

Each "episode" is constructed as follows:

  1. Sample a task T from the training data: choose N classes, with K examples per class.
  2. To form one episode, sample a label set L (e.g. {cats, dogs}) and then use L to sample the support set S and a batch B of examples to evaluate the loss on (a minimal sampler sketch follows).
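
A minimal sketch of this episode sampling, assuming the training data is given as a dict from class label to a list of examples; the function and parameter names are mine:

```python
import random

def sample_episode(data_by_class, n_way=5, k_shot=1, query_per_class=1):
    """Sample a label set L, then a support set S and a query batch B from L."""
    label_set = random.sample(list(data_by_class), n_way)       # e.g. {cats, dogs, ...}
    support, query = [], []
    for label in label_set:
        examples = random.sample(data_by_class[label], k_shot + query_per_class)
        support += [(x, label) for x in examples[:k_shot]]      # S
        query += [(x, label) for x in examples[k_shot:]]        # B, used for the loss
    random.shuffle(support)
    random.shuffle(query)
    return support, query
```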

References

tags: fewshot learning