# What is beam search and why beam search? 參考自 [如何通俗的理解beam search?](https://zhuanlan.zhihu.com/p/82829880) 搜尋法並不是一種模型,**而是針對模型來決定相對應輸出的 後處理演算法**。使用情境多應用於生成模型(Decoder model) 預測結果選擇相對應輸出時使用。 常用搜尋法有: + Exhaustive search 窮舉搜尋 + Greedy search 貪心搜尋 + Beam search 束搜尋 舉例來說,翻譯任務: > 中文输入:"我" "恨" "你" > 英文输出:"I" "H" "U" ## Exhaustive search + 最直觀的方法就是窮舉所有可能的輸出序列,3個時間步長,每個步長3種選擇。 ``` I-I-I I-I-H I-I-U I-H-I I-H-H I-H-U I-U-I I-U-H I-U-U H-I-I H-I-H H-I-U H-H-I H-H-H H-H-U H-U-I H-U-H H-U-U U-I-I U-I-H U-I-U U-H-I U-H-H U-H-U U-U-I U-U-H U-U-U ``` + 計算複雜度太高,當輸出詞典稍微大一點根本無法使用。 ## Greedy search + 貪心算法在翻譯每個字的時候,直接選擇條件概率最大的候選值作為當前最優。 ![](https://hackmd.io/_uploads/rJ5EbSwj2.png) + 貪心算法每一步選擇中都採取在當前狀態下最好或最優的選擇,通過這種局部最優策略期望產生全局最優解,效率最高但並不能保證最終的結果一定是全局最優的。 ## Beam search + beam search是對greedy search的一個改進算法。相對greedy search擴大了搜索空間。 + beam search有一個超參數beam size(束寬),設為 k。第一個時間步長,選取當前條件概率最大的 k 個詞,當做候選輸出序列的第一個詞。之後的每個時間步長,基於上個步長的輸出序列,挑選出所有組合中條件概率最大的 k 個,作為該時間步長下的候選輸出序列。始終保持 k 個候選。最後從 k 個候選中挑出最優的。 ![](https://hackmd.io/_uploads/SJTyzSDin.png) ![](https://hackmd.io/_uploads/rJfgMrPsh.png) ![](https://hackmd.io/_uploads/S1dxzHPi2.png) + beam search不保證全局最優,但是比greedy search搜索空間更大,一般結果比greedy search要好。 + greedy search 可以看做是 beam size = 1時的 beam search。 # Implement 參考自 [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/) + Greedy Search ```python from numpy import array from numpy import argmax # greedy decoder def greedy_decoder(data): # index for largest probability each row return [argmax(s) for s in data] # define a sequence of 10 words over a vocab of 5 words data = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1] ] data = array(data) # decode sequence result = greedy_decoder(data) print(result) ``` ```python [4, 0, 4, 0, 4, 0, 4, 0, 4, 0] ``` + Beam Search ```python from math import log from numpy import array from numpy import argmax # beam search def beam_search_decoder(data, k): sequences = [[list(), 0.0]] # walk over each step in sequence for row in data: all_candidates = list() # expand each current candidate print(f"Step {id_} in searching: {sequences}\n") for i in range(len(sequences)): seq, score = sequences[i] for j in range(len(row)): candidate = [seq + [j], score - log(row[j])] all_candidates.append(candidate) # order all candidates by score ordered = sorted(all_candidates, key=lambda tup:tup[1]) # select k best sequences = ordered[:k] return sequences # define a sequence of 10 words over a vocab of 5 words data = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1] ] data = array(data) # decode sequence result = beam_search_decoder(data, 3) # print result for seq in result: print(seq) ``` ```python # Each step in searching... Step 0 in searching: [[[], 0.0]] Step 1 in searching: [[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]] Step 2 in searching: [[[4, 0], 1.3862943611198906], [[4, 1], 1.6094379124341003], [[4, 2], 1.8971199848858813]] Step 3 in searching: [[[4, 0, 4], 2.0794415416798357], [[4, 0, 3], 2.3025850929940455], [[4, 0, 2], 2.5902671654458267]] Step 4 in searching: [[[4, 0, 4, 0], 2.772588722239781], [[4, 0, 4, 1], 2.995732273553991], [[4, 0, 4, 2], 3.283414346005772]] Step 5 in searching: [[[4, 0, 4, 0, 4], 3.4657359027997265], [[4, 0, 4, 0, 3], 3.6888794541139363], [[4, 0, 4, 0, 2], 3.9765615265657175]] Step 6 in searching: [[[4, 0, 4, 0, 4, 0], 4.1588830833596715], [[4, 0, 4, 0, 4, 1], 4.382026634673881], [[4, 0, 4, 0, 4, 2], 4.669708707125663]] Step 7 in searching: [[[4, 0, 4, 0, 4, 0, 4], 4.852030263919617], [[4, 0, 4, 0, 4, 0, 3], 5.075173815233827], [[4, 0, 4, 0, 4, 0, 2], 5.362855887685607]] Step 8 in searching: [[[4, 0, 4, 0, 4, 0, 4, 0], 5.545177444479562], [[4, 0, 4, 0, 4, 0, 4, 1], 5.768320995793772], [[4, 0, 4, 0, 4, 0, 4, 2], 6.056003068245553]] Step 9 in searching: [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 6.238324625039508], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 6.461468176353717], [[4, 0, 4, 0, 4, 0, 4, 0, 2], 6.749150248805498]] ``` ```python # Results: [[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 6.931471805599453] [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 7.154615356913663] [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 7.154615356913663] ``` ### How about applying both methods in auto regression? + greedy search ```python def greedy_decode(model, input_seq, max_length=50): output_seq = [input_seq] for _ in range(max_length): logits = model(output_seq) next_token = argmax(logits) output_seq.append(next_token) if next_token == EOS_TOKEN: break return output_seq ``` + beam search ```python def beam_search(model, input_seq, k=3, max_length=50): beams = [(input_seq, 0)] # each beam is (sequence, score) for _ in range(max_length): new_beams = [] for seq, score in beams: logits = model(seq) probs = softmax(logits) topk_tokens, topk_probs = get_topk(probs, k) for token, prob in zip(topk_tokens, topk_probs): new_seq = seq + [token] new_score = score + log(prob) # Use log for numerical stability new_beams.append((new_seq, new_score)) # Sort and keep the top k sequences beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:k] # Stop early if all top sequences end with <eos> if all(seq[-1] == EOS_TOKEN for seq, _ in beams): break return beams[0] # return the top sequence ``` + Ours ```python def greedy_auto_regression( self, pixel_values, dummy_input_ids, max_length=None, eos_id=None ): # TODO(Weber): this step may have to refactor if implement beam search. result = [] for _ in range(cast(int, max_length) - 1): with torch.no_grad(): logits = self.model(pixel_values, dummy_input_ids) # (seq_len, batch, vocab_size) to (batch, seq_len, vocab_size) logits = logits.permute(1, 0, 2) next_token_logits = logits[:, -1] next_token = torch.argmax(next_token_logits) # update the dummy_input_ids for greedy decoding. dummy_input_ids = torch.concat( [dummy_input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1 ) # concat the result last part of logits. result.append(next_token) if next_token.int() == eos_id: break result = torch.stack(result) return result def beam_auto_regression( self, pixel_values, dummy_input_ids, max_length=None, eos_id=None, k=3 ): def get_topk(probs, k): topk_probs, topk_indices = torch.topk(probs, k) return topk_indices, topk_probs beams = [([], 0)] # each beam is (sequence, score) for _ in range(cast(int, max_length) - 1): new_beams = [] for seq, score in beams: with torch.no_grad(): logits = self.model(pixel_values, dummy_input_ids) # We need softmax here probs = self.softmax(logits) topk_tokens, topk_probs = get_topk(probs, k) for token, prob in zip(topk_tokens, topk_probs): new_seq = seq + [token] new_score = score + math.log(prob) # Use log for numerical stability new_beams.append((new_seq, new_score)) # Sort and keep the top k sequences beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:k] # Stop early if all top sequences end with <eos> if all(seq[-1] == eos_id for seq, _ in beams): break return beams[0][0] # return the top sequence ```