### Reference
1. **論文** [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/pdf/2305.10601.pdf)
2. **Github code** [princeton-nlp/tree-of-thought-llm (Github)](https://github.com/princeton-nlp/tree-of-thought-llm)
3. **論文導讀** [Tree of Thoughts: Deliberate Problem Solving with LLMs 論文導讀](https://hackmd.io/@chrizeroxtwo/BJCRMevwh)
#### 使用模組
```python!
import re
import json
import os
from tasks.base import Task, DATA_PATH
from prompts.crosswords import *
from models import gpt
```
#### 定義解題環境
```python!
class MiniCrosswordsEnv:
    """5x5 mini-crossword environment used as the Tree-of-Thoughts search task.

    Puzzles are loaded from ``data/crosswords/<file>``. Each entry is a pair
    ``(clues, board_gt)``: ``clues`` holds the 10 clue strings (h1-h5, then
    v1-v5) and ``board_gt`` the 25 solution letters in row-major order.
    """

    def __init__(self, file='mini0505.json'):
        path = f'data/crosswords/{file}'  # puzzle file path
        # Context manager so the file handle is closed right after loading.
        with open(path, encoding='utf-8') as f:
            self.file = json.load(f)  # list of (clues, board_gt) puzzles
        self.n = len(self.file)  # number of puzzles
        self.cache = {}  # rendered state -> candidate scores (filled by the search code)
        self.idx = None  # index of the current puzzle
        self.times = 0
        self.prompt_status_cache = {}  # value prompt -> raw GPT response

    def __len__(self):
        """Return the number of puzzles in the loaded file."""
        return self.n

    def reset(self, idx, board=None, status=None, steps=None):
        """Load puzzle ``idx`` and return its rendered state.

        The optional ``board`` (25 letters, '_' = empty), ``status`` (10 ints) and
        ``steps`` restore a saved search state instead of starting from a blank board.
        """
        self.idx = idx
        self.data, self.board_gt = self.file[idx]  # clues, solution letters
        self.board = ['_'] * 25  # current letters, row-major
        self.ans = ['_____'] * 10  # current words: h1-h5, then v1-v5
        self.ans_gt = self.get_ans(self.board_gt)  # solution words
        self.steps = 0
        # Per-word status: 0 unfilled, 1 filled, 2 filled then changed by a crossing word.
        self.status = [0] * 10
        if board is not None:
            self.board = board
            self.ans = self.get_ans(self.board)
        if status is not None:
            self.status = status
        if steps is not None:
            self.steps = steps
        return self.render()

    def prompt_status(self):
        """Ask GPT to judge each (partially) filled word against its clue.

        Words with at most one letter filled are skipped. GPT responses are cached
        per prompt in ``self.prompt_status_cache``.

        Returns:
            dict ``{'sure': n, 'maybe': n, 'impossible': n}`` counting the verdicts,
            where a verdict is the stripped last line of GPT's response.
        """
        count = {'sure': 0, 'maybe': 0, 'impossible': 0}
        for ans, data in zip(self.ans, self.data):
            if ans.count('_') >= 4:
                continue
            ans = ' '.join(ans.lower())
            line = f'{data}: {ans}'  # clue paired with the current (letter-spaced) answer
            prompt = value_prompt.format(input=line)
            if prompt in self.prompt_status_cache:
                res = self.prompt_status_cache[prompt]
            else:
                res = gpt(prompt)[0]
                self.prompt_status_cache[prompt] = res
            res = res.split('\n')[-1].strip()
            if res in count:
                count[res] += 1
        return count

    def render_gt_board(self):
        """Return the solution as a 5x5 grid of space-separated letters."""
        s = "GT Board:\n"
        for i in range(5):
            s += ' '.join(self.board_gt[i*5:(i+1)*5]) + '\n'
        return s

    def render_board(self):
        """Return the current board as a 5x5 grid ('_' marks empty cells)."""
        s = "Current Board:\n"
        for i in range(5):
            s += ''.join(self.board[i*5:(i+1)*5]) + '\n'
        return s

    def _render_entries(self, values=None, status=None):
        """Return one line ``'<label>. <clue>[: <value>]'`` per word whose status matches.

        Labels are h1-h5 for words 0-4 and v1-v5 for words 5-9. ``status=None``
        keeps every word; ``values=None`` omits the ``': <value>'`` suffix.
        """
        lines = []
        for i in range(10):
            if status is not None and self.status[i] != status:
                continue
            label = f'h{i + 1}' if i < 5 else f'v{i - 4}'
            line = f'{label}. {self.data[i]}'
            if values is not None:
                line += f': {values[i]}'
            lines.append(line + '\n')
        return ''.join(lines)

    def render_clues(self, status=None):
        """List the clues, optionally only those of words with the given status."""
        return self._render_entries(status=status)

    def render_ans(self, status=None):
        """List each clue with the current answer, optionally filtered by status."""
        return self._render_entries(self.ans, status)

    def render_gt_ans(self, status=None):
        """List each clue with its ground-truth answer, filtered by the current status."""
        return self._render_entries(self.ans_gt, status)

    def render(self, status=True):
        """Render the board plus clue/answer lines, grouped by word status when ``status``."""
        if not status:
            return self.render_board() + '\n' + self.render_ans()
        return (self.render_board()
                + '\nUnfilled:\n' + self.render_ans(status=0)
                + '\nFilled:\n' + self.render_ans(status=1)
                + '\nChanged:\n' + self.render_ans(status=2))

    def get_ans(self, board):
        """Split a 25-letter row-major board into its 10 words (h1-h5, then v1-v5).

        Example: ['A', 'G', 'E', 'N', 'D', ..., 'E', 'R'] -> ['AGEND', 'MOTOR', ..., 'DRYER'].
        """
        # Pure helper: it must not touch self.steps, otherwise every step() would
        # zero the step counter and the step limit could never trigger.
        rows = [''.join(board[i*5:(i+1)*5]) for i in range(5)]
        cols = [''.join(board[i::5]) for i in range(5)]
        return rows + cols

    def step(self, action):
        """Apply one word placement such as ``"h1. apple"`` (only the last line is read).

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is True only when the whole board
            matches the solution. ``done`` is set on a full solve or once 20 steps have
            been taken. ``info`` holds letter accuracy, word accuracy and the game
            result. A malformed action returns ``(message, 0, False, {})`` and leaves
            the board unchanged (the step is still counted).
        """
        self.steps += 1
        parts = action.split('\n')[-1].split('. ')
        if len(parts) != 2:
            return 'Invalid! Format should be like "h1. apple"', 0, False, {}
        pos, word = parts
        if len(word) != 5:
            return 'Invalid! Word should have 5 letters.', 0, False, {}
        direction, number = pos[:1], pos[1:]
        # Reject anything outside h1-h5 / v1-v5: out-of-range slice assignment would
        # silently grow the board ("h7", "h0") or raise ("v7", "hx").
        if direction not in ('h', 'v') or not number.isdecimal() or not 1 <= int(number) <= 5:
            return 'Invalid! Position should be h1-h5 or v1-v5', 0, False, {}
        idx = int(number) - 1
        letters = list(word.upper())
        if direction == 'h':
            self.board[idx*5:(idx+1)*5] = letters
        else:
            self.board[idx::5] = letters
            idx += 5  # vertical words occupy ans/status slots 5-9
        self.new_ans = self.get_ans(self.board)
        # A word in which any previously written letter changed becomes "changed" (2).
        self.status = [
            2 if any(old != new and old != '_' for old, new in zip(ans, new_ans)) else status
            for status, ans, new_ans in zip(self.status, self.ans, self.new_ans)
        ]
        self.status[idx] = 1  # the word just written is "filled"
        self.ans = self.new_ans
        r_all = (self.board == self.board_gt)
        r_letter = sum(a == b for a, b in zip(self.board, self.board_gt)) / 25
        r_word = sum(a == b for a, b in zip(self.ans, self.ans_gt)) / 10
        return self.render(), r_all, (r_all or self.steps >= 20), {'r_letter': r_letter, 'r_word': r_word, 'r_game': r_all}
```
### DFS Implementation
```python!
import re
import copy
import json
from models import gpt
from prompts.crosswords import propose_prompt, value_prompt
from models import gpt
from tasks.crosswords import MiniCrosswordsEnv
# Environment over the default puzzle file; used by the DFS driver loop further down.
env = MiniCrosswordsEnv(file='mini0505.json')
def parse_line(input_str):
    """Parse one proposal line of the form ``"h1. apple (medium)"``.

    Returns:
        ``[position, word, confidence]`` (e.g. ``['h1', 'apple', 'medium']``) when
        the line matches: position h1-h5/v1-v5, a 5-letter word, and a confidence
        of certain/high/medium/low, optionally followed by more text. Otherwise None.
    """
    found = re.match(r'^([hv][1-5])\. ([a-zA-Z]{5,5}) \((certain|high|medium|low)\).*$', input_str)
    return None if found is None else list(found.groups())
# Numeric score for each GPT confidence label. TODO: ad hoc
confidence_to_value = {
    'certain': 1,
    'high': 0.5,
    'medium': 0.2,
    'low': 0.1,
}


def parse_response(response):
    """Parse every well-formed proposal line of a GPT ``response``.

    Returns:
        list of ``(action, score)`` pairs such as ``('h1. apple', 0.2)``, with
        position and word lowercased and the score taken from
        ``confidence_to_value``; None when no line matched.
    """
    pairs = []
    for raw in response.split('\n'):
        parsed = parse_line(raw)
        if parsed is None:
            continue
        position, word, confidence = parsed
        pairs.append((position.lower() + '. ' + word.lower(), confidence_to_value.get(confidence, 0)))
    return pairs or None
def prompt_wrap(obs):
    """Embed the rendered board state ``obs`` into the proposal prompt's ``{input}`` slot."""
    return propose_prompt.format(input=obs)


def get_candidates_to_scores(env):
    """Ask GPT for candidate words for the current state and aggregate their scores.

    The rendered state (``env.render()``) is the key into ``env.cache``, so each
    distinct state is sent to GPT only once.

    Returns:
        dict mapping actions to summed confidence scores across completions,
        e.g. ``{'h1. award': 1, 'h2. stand': 0.5, 'v1. asite': 0.1}``; may be empty.
    """
    obs = env.render()
    if obs in env.cache:
        print('cache hit')
        return env.cache[obs]
    print('call gpt')
    # The proposal prompt wraps the 5x5 mini-crossword state and asks GPT to list
    # possible answers for unfilled or changed words with a confidence level
    # (certain/high/medium/low) in the format "h1. apple (medium)".
    # n=8 samples 8 independent completions (8 thought candidates).
    responses = gpt(prompt_wrap(obs), model='gpt-4', n=8)
    candidates_to_scores = {}
    for response in responses:
        # parse_response returns None when a completion has no well-formed line.
        for candidate, score in parse_response(response) or []:
            candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score
    env.cache[obs] = candidates_to_scores
    return candidates_to_scores
def propose_score(env, idx):
    """Greedily solve puzzle ``idx``, one GPT proposal round per step, no backtracking.

    Each round samples 5 proposal completions and sums the confidence scores per
    candidate. It then commits the highest-ranked candidate whose placement does
    not overwrite letters of an already-filled word; this is checked on a deep
    copy of the env. If every candidate conflicts, the lowest-ranked one is
    committed. The loop stops when ``env.step`` reports done or GPT proposes nothing.

    Returns:
        list of the per-step ``info`` dicts returned by ``env.step``.
    """
    obs = env.reset(idx)
    infos = []
    done = False
    while not done:
        responses = gpt(prompt_wrap(obs), model='gpt-4', n=5)  # 5 sampled completions
        scores = {}
        for response in responses:
            for cand, weight in parse_response(response) or []:
                scores[cand] = scores.get(cand, 0) + weight
        print(sorted(scores.items(), key=lambda x: x[1], reverse=True))
        if not scores:
            break
        ranked = sorted(scores, key=scores.get, reverse=True)
        chosen = ranked[-1]  # fallback when every candidate conflicts
        for cand in ranked:
            trial = copy.deepcopy(env)
            trial.step(cand)
            if all(s != 2 for s in trial.status):
                chosen = cand
                break
        print(chosen)
        # info: letter accuracy, word accuracy and whether the puzzle is solved.
        obs, _, done, info = env.step(chosen)
        print(obs)
        print(env.steps, info)
        print('-------------------\n\n\n')
        infos.append(info)
    return infos
```
```python!
def dfs(env, actions, infos, time_limit, prune, max_per_state, max_steps=10):
    """Depth-first search over crossword word placements (Tree-of-Thoughts DFS).

    At each node, GPT proposes scored candidate words (``get_candidates_to_scores``),
    and they are tried in descending score order. A candidate is expanded only when
    all of the following hold:
    - the budget is not exhausted (``len(infos) < time_limit``);
    - the env is below ``max_steps`` steps;
    - the placement did not overwrite letters of an already-filled word (status 2).
    At most ``max_per_state`` such candidates are expanded per node.

    Args:
        env: MiniCrosswordsEnv at the current node; its board, status and steps
            are restored after every candidate, including when the loop breaks.
        actions: stack of actions on the current path (pushed/popped in place).
        infos: list collecting one record per expanded node (appended in place).
        time_limit: maximum number of records in ``infos``.
        prune: if True, do not descend below a node where the value prompt judged
            any filled word "impossible".
        max_per_state: maximum number of expanded children per node.
        max_steps: expand only while ``env.steps`` is below this (default 10).

    Returns:
        ``(0, [], [])`` when GPT proposes no candidates, otherwise None
        (results are accumulated in ``infos``).
    """
    candidates_to_scores = get_candidates_to_scores(env)
    if len(candidates_to_scores) == 0:
        return 0, [], []
    print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))
    # Back up this node so every candidate is tried from the same state.
    board, status, steps = env.board.copy(), env.status.copy(), env.steps

    def restore():
        # reset() rebuilds ans from board, so board/status/steps fully describe a node.
        env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)

    cnt_per_state = 0
    for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):
        _, _, _, step_info = env.step(action)
        # Not violating any existing constraints, and within budget and depth.
        if len(infos) < time_limit and env.steps < max_steps and not any(s == 2 for s in env.status):
            cnt_per_state += 1
            if cnt_per_state > max_per_state:
                restore()  # leave env at this node's state for the caller
                break
            count = env.prompt_status()
            actions.append(action)
            print(len(infos))
            print(actions)
            print(env.render_board())
            print(step_info)
            print(count)
            if infos:
                best = max(infos, key=lambda x: x['info']['r_word'])
                print('best', best)
            print('--------------')
            print()
            infos.append({'total_step': len(infos), 'env_step': env.steps,
                          'actions': actions.copy(), 'info': step_info, 'count': count})
            # With pruning, only descend when no filled word was judged impossible.
            if not prune or count['impossible'] < 1:
                dfs(env, actions, infos, time_limit, prune, max_per_state, max_steps)
            actions.pop()
        restore()
# DFS with pruning on 20 sampled puzzles (indices 0, 5, ..., 95).
import os

# open(..., 'w') below fails if the directory is missing, which would lose the
# (expensive) GPT search results, so create it up front.
os.makedirs('logs/crosswords', exist_ok=True)
infoss = []
for i in range(0, 100, 5):
    env.reset(i)
    infos = []
    actions = []
    dfs(env, actions, infos, 100, prune=True, max_per_state=3)
    infoss.append(infos)
    # Rewrite the log after every puzzle so finished searches survive a later failure.
    with open('logs/crosswords/infoss_dfs_prune.json', 'w') as fout:
        json.dump(infoss, fout)
```
### Model
```python!
import os
import openai
import backoff
# Running token counters, accumulated by chatgpt() and read by gpt_usage().
completion_tokens = 0
prompt_tokens = 0

# Configure the OpenAI client from environment variables.
api_key = os.getenv("OPENAI_API_KEY", "")
if not api_key:
    print("Warning: OPENAI_API_KEY is not set")
else:
    openai.api_key = api_key

api_base = os.getenv("OPENAI_API_BASE", "")
if api_base:
    print("Warning: OPENAI_API_BASE is set to {}".format(api_base))
    openai.api_base = api_base
def _is_permanent_openai_error(error):
    """Return True for OpenAI errors that retrying cannot fix (bad key, bad request, no access)."""
    return isinstance(error, (openai.error.AuthenticationError,
                              openai.error.InvalidRequestError,
                              openai.error.PermissionError))


# Exponential backoff on OpenAI errors; without `giveup`, a missing or invalid
# API key (or an over-long prompt) would be retried forever and hang the run.
@backoff.on_exception(backoff.expo, openai.error.OpenAIError, giveup=_is_permanent_openai_error)
def completions_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create(**kwargs)``, retrying transient errors."""
    return openai.ChatCompletion.create(**kwargs)
def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
    """Send ``prompt`` as a single user message and return ``n`` completion strings."""
    return chatgpt(
        [{"role": "user", "content": prompt}],
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        n=n,
        stop=stop,
    )
def chatgpt(messages, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
    """Request ``n`` chat completions for ``messages`` and return their text contents.

    Requests are issued in batches of at most 20 completions. Each response's
    prompt/completion token usage is added to the module-level counters.
    """
    global completion_tokens, prompt_tokens
    outputs = []
    remaining = n
    while remaining > 0:
        batch = min(remaining, 20)
        remaining -= batch
        res = completions_with_backoff(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            n=batch,
            stop=stop,
        )
        outputs += [choice["message"]["content"] for choice in res["choices"]]
        usage = res["usage"]
        completion_tokens += usage["completion_tokens"]
        prompt_tokens += usage["prompt_tokens"]
    return outputs
def gpt_usage(backend="gpt-4"):
    """Return the token counters and their estimated USD cost for ``backend``.

    Returns:
        dict with ``completion_tokens``, ``prompt_tokens`` and ``cost``.

    Raises:
        ValueError: if ``backend`` has no pricing entry (previously this surfaced
            as an UnboundLocalError on ``cost``).
    """
    if backend == "gpt-4":
        # $0.06 / 1K completion tokens, $0.03 / 1K prompt tokens.
        cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
    elif backend == "gpt-3.5-turbo":
        # Flat $0.002 / 1K tokens (launch list price); 0.0002 was 10x too low.
        cost = (completion_tokens + prompt_tokens) / 1000 * 0.002
    else:
        raise ValueError(f"Unknown backend {backend!r}; expected 'gpt-4' or 'gpt-3.5-turbo'")
    return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
```