# What is beam search and why beam search?
Adapted from [如何通俗的理解beam search?](https://zhuanlan.zhihu.com/p/82829880).
A search method is not a model in itself; **it is a post-processing algorithm that selects the final output from a model's predictions**. It is typically used when a generative (decoder) model has to turn its per-step probability distributions into an output sequence.
Common search methods include:
+ Exhaustive search
+ Greedy search
+ Beam search

Take a translation task as an example:
> Chinese input: "我" "恨" "你"
> English output: "I" "H" "U"
## Exhaustive search
+ The most intuitive approach is to enumerate every possible output sequence. With 3 time steps and 3 choices per step, there are 3^3 = 27 candidates:
```
I-I-I
I-I-H
I-I-U
I-H-I
I-H-H
I-H-U
I-U-I
I-U-H
I-U-U
H-I-I
H-I-H
H-I-U
H-H-I
H-H-H
H-H-U
H-U-I
H-U-H
H-U-U
U-I-I
U-I-H
U-I-U
U-H-I
U-H-H
U-H-U
U-U-I
U-U-H
U-U-U
```
+ The computational complexity is prohibitive: the number of candidate sequences grows exponentially with the output length, so the method becomes unusable as soon as the output vocabulary gets even slightly larger. A minimal sketch of this brute-force scoring follows below.
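To make the brute-force idea concrete, here is a minimal sketch that scores all 27 sequences from the example above. The 3x3 probability table is made up purely for illustration; a real model would produce (conditional) probabilities at every step.
```python
from itertools import product

import numpy as np

vocab = ["I", "H", "U"]
# probs[t][v]: illustrative probability of word v at time step t
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
])

best_seq, best_score = None, 0.0
for seq in product(range(len(vocab)), repeat=3):  # all 3**3 = 27 sequences
    score = np.prod([probs[t][v] for t, v in enumerate(seq)])
    if score > best_score:
        best_seq, best_score = seq, score

print("-".join(vocab[v] for v in best_seq), best_score)  # I-H-U, probability ~0.252
```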
## Greedy search
+ When generating each word, greedy search simply takes the candidate with the highest conditional probability as the current best.

+ At every step the greedy algorithm makes the best choice available in the current state, hoping that this local-optimum strategy will produce a globally optimal sequence. It is the most efficient of the three methods, but the final result is not guaranteed to be globally optimal, as the toy example below shows.
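The pitfall is easiest to see with a tiny made-up example in which the second-step probabilities depend on the first word chosen (all numbers are purely illustrative): greedy search commits to the first word that looks best on its own, even though a slightly worse first word enables a much better continuation.
```python
# Illustrative conditional probabilities for a 2-step, 2-word toy problem.
step1 = {"A": 0.6, "B": 0.4}                      # P(first word)
step2 = {"A": {"A": 0.55, "B": 0.45},             # P(second word | first = "A")
         "B": {"A": 0.90, "B": 0.10}}             # P(second word | first = "B")

# Greedy: pick the locally best word at each step.
first = max(step1, key=step1.get)                 # "A"
second = max(step2[first], key=step2[first].get)  # "A"
print(first, second, step1[first] * step2[first][second])  # A A ~0.33

# Global optimum: enumerate all four sequences.
best = max(
    ((w1, w2, step1[w1] * step2[w1][w2]) for w1 in step1 for w2 in step2[w1]),
    key=lambda t: t[2],
)
print(best)  # ('B', 'A', ~0.36), better than the greedy result
```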
## Beam search
+ Beam search is an improvement on greedy search: it enlarges the search space relative to greedy search.
+ Beam search has one hyperparameter, the beam size (beam width), denoted k. At the first time step it keeps the k words with the highest conditional probability as candidate first words. At every later time step it extends each candidate sequence from the previous step and, among all extensions, keeps the k with the highest conditional probability as the candidates for that step. It always maintains exactly k candidates, and at the end it returns the best of the k.



+ Beam search does not guarantee a globally optimal result either, but because it searches a larger space than greedy search, its output is usually better.
+ Greedy search can be viewed as beam search with beam size = 1. A worked example with k = 2 follows below.
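As a quick check, the made-up toy problem from the greedy section can be decoded with beam size k = 2 (a minimal sketch, scoring with negative log probabilities as in the implementation below): keeping two candidates lets the sequence starting with "B" survive the first step and end up on top.
```python
from math import log

step1 = {"A": 0.6, "B": 0.4}                      # P(first word)
step2 = {"A": {"A": 0.55, "B": 0.45},             # P(second word | first = "A")
         "B": {"A": 0.90, "B": 0.10}}             # P(second word | first = "B")

k = 2
# Step 1: keep the k first words with the lowest negative log probability.
beams = sorted(
    (([w], -log(p)) for w, p in step1.items()),
    key=lambda b: b[1],
)[:k]
# Step 2: expand every kept candidate and again keep the k best overall.
candidates = []
for seq, score in beams:
    for w, p in step2[seq[-1]].items():
        candidates.append((seq + [w], score - log(p)))
beams = sorted(candidates, key=lambda b: b[1])[:k]

print(beams[0])  # (['B', 'A'], ~1.02), i.e. probability 0.4 * 0.9 = 0.36
```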
# Implementation
Adapted from [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/).
+ Greedy Search
```python
from numpy import array
from numpy import argmax
# greedy decoder
def greedy_decoder(data):
    # index of the largest probability in each row
    return [argmax(s) for s in data]
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]
]
data = array(data)
# decode sequence
result = greedy_decoder(data)
print(result)
```
```python
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
```
+ Beam Search
```python
from math import log
from numpy import array
from numpy import argmax
# beam search decoder: keep the k most probable sequences at every time step
def beam_search_decoder(data, k):
    sequences = [[list(), 0.0]]
    # walk over each step in sequence
    for id_, row in enumerate(data):
        # show the candidates kept so far
        print(f"Step {id_} in searching: {sequences}\n")
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                # score is the cumulative negative log probability, so lower is better
                candidate = [seq + [j], score - log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]
]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)
```
```python
# Each step in searching...
Step 0 in searching: [[[], 0.0]]
Step 1 in searching: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]
Step 2 in searching: [[[4, 0], 1.3862943611198906], [[4, 1], 1.6094379124341003], [[4, 2], 1.8971199848858813]]
Step 3 in searching: [[[4, 0, 4], 2.0794415416798357], [[4, 0, 3], 2.3025850929940455], [[4, 0, 2], 2.5902671654458267]]
Step 4 in searching: [[[4, 0, 4, 0], 2.772588722239781], [[4, 0, 4, 1], 2.995732273553991], [[4, 0, 4, 2], 3.283414346005772]]
Step 5 in searching: [[[4, 0, 4, 0, 4], 3.4657359027997265], [[4, 0, 4, 0, 3], 3.6888794541139363], [[4, 0, 4, 0, 2], 3.9765615265657175]]
Step 6 in searching: [[[4, 0, 4, 0, 4, 0], 4.1588830833596715], [[4, 0, 4, 0, 4, 1], 4.382026634673881], [[4, 0, 4, 0, 4, 2], 4.669708707125663]]
Step 7 in searching: [[[4, 0, 4, 0, 4, 0, 4], 4.852030263919617], [[4, 0, 4, 0, 4, 0, 3], 5.075173815233827], [[4, 0, 4, 0, 4, 0, 2], 5.362855887685607]]
Step 8 in searching: [[[4, 0, 4, 0, 4, 0, 4, 0], 5.545177444479562], [[4, 0, 4, 0, 4, 0, 4, 1], 5.768320995793772], [[4, 0, 4, 0, 4, 0, 4, 2], 6.056003068245553]]
Step 9 in searching: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 6.238324625039508], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 6.461468176353717], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 6.749150248805498]]
```
```python
# Results:
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 6.931471805599453]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 7.154615356913663]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 7.154615356913663]
```
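The scores shown above are cumulative negative log probabilities (the decoder minimizes `-sum(log p)`), so lower is better. The best sequence's score of 6.931... is exactly 10 * -ln(0.5), since the largest probability at every one of the 10 steps is 0.5.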
### How about applying both methods in autoregressive decoding?
+ greedy search
```python
# Sketch of greedy decoding in an autoregressive loop
# (`model`, `argmax`, and `EOS_TOKEN` are assumed to be defined elsewhere).
def greedy_decode(model, input_seq, max_length=50):
    output_seq = list(input_seq)
    for _ in range(max_length):
        # score the next token given everything decoded so far
        logits = model(output_seq)
        next_token = argmax(logits)
        output_seq.append(next_token)
        if next_token == EOS_TOKEN:
            break
    return output_seq
```
+ beam search
```python
# Sketch of beam search in an autoregressive loop
# (`model`, `softmax`, `get_topk`, `log`, and `EOS_TOKEN` are assumed helpers).
def beam_search(model, input_seq, k=3, max_length=50):
    beams = [(list(input_seq), 0)]  # each beam is (sequence, score)
    for _ in range(max_length):
        new_beams = []
        for seq, score in beams:
            # score the next token given this beam's sequence so far
            logits = model(seq)
            probs = softmax(logits)
            topk_tokens, topk_probs = get_topk(probs, k)
            for token, prob in zip(topk_tokens, topk_probs):
                new_seq = seq + [token]
                new_score = score + log(prob)  # use log for numerical stability
                new_beams.append((new_seq, new_score))
        # Sort and keep the top k sequences
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:k]
        # Stop early if all top sequences end with <eos>
        # (for simplicity, finished beams are still expanded above)
        if all(seq[-1] == EOS_TOKEN for seq, _ in beams):
            break
    return beams[0]  # return the top sequence
```
+ Ours
```python
# NOTE: `torch`, `math`, and `typing.cast` are imported at module level;
# `self.model` and `self.softmax` are attributes of the surrounding class.
def greedy_auto_regression(
    self, pixel_values, dummy_input_ids, max_length=None, eos_id=None
):
    # TODO(Weber): this step may have to be refactored if beam search is implemented.
    result = []
    for _ in range(cast(int, max_length) - 1):
        with torch.no_grad():
            logits = self.model(pixel_values, dummy_input_ids)
        # (seq_len, batch, vocab_size) to (batch, seq_len, vocab_size)
        logits = logits.permute(1, 0, 2)
        next_token_logits = logits[:, -1]
        next_token = torch.argmax(next_token_logits)
        # update the dummy_input_ids for the next greedy decoding step.
        dummy_input_ids = torch.concat(
            [dummy_input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1
        )
        # keep the newly decoded token.
        result.append(next_token)
        if next_token.int() == eos_id:
            break
    result = torch.stack(result)
    return result

def beam_auto_regression(
    self, pixel_values, dummy_input_ids, max_length=None, eos_id=None, k=3
):
    def get_topk(probs, k):
        topk_probs, topk_indices = torch.topk(probs, k)
        return topk_indices, topk_probs

    beams = [([], 0.0)]  # each beam is (sequence, score)
    for _ in range(cast(int, max_length) - 1):
        new_beams = []
        for seq, score in beams:
            # Each beam decodes with its own input ids: the prompt plus the
            # tokens this beam has chosen so far (mirrors the greedy version).
            input_ids = dummy_input_ids
            if seq:
                input_ids = torch.concat(
                    [dummy_input_ids, torch.stack(seq).view(1, -1)], dim=1
                )
            with torch.no_grad():
                logits = self.model(pixel_values, input_ids)
            # (seq_len, batch, vocab_size) to (batch, seq_len, vocab_size),
            # keeping only the last position; batch size 1 is assumed,
            # as in the greedy version.
            next_token_logits = logits.permute(1, 0, 2)[:, -1]
            # We need softmax here so the scores can accumulate as log probabilities.
            probs = self.softmax(next_token_logits).squeeze(0)
            topk_tokens, topk_probs = get_topk(probs, k)
            for token, prob in zip(topk_tokens, topk_probs):
                new_seq = seq + [token]
                new_score = score + math.log(prob)  # use log for numerical stability
                new_beams.append((new_seq, new_score))
        # Sort and keep the top k sequences
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:k]
        # Stop early if all top sequences end with <eos>
        if all(seq[-1] == eos_id for seq, _ in beams):
            break
    return beams[0][0]  # return the top sequence
```