## Final PyQt Program
```python=
import sys
import argparse
import speech_recognition as sr
import pyttsx3
import threading
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QVBoxLayout, QWidget, QFileDialog, QComboBox, QTextEdit, QStyledItemDelegate
from PyQt5.QtGui import QPixmap, QPainter, QPen, QColor, QFont, QIcon
from PyQt5.QtCore import Qt, QSize, QRect
import torch
import torchvision.transforms as T
from PIL import Image
from models import build_model  # custom model-building function
from gtts import gTTS
import os
# Object classes and relation classes used by the object and relation detection below
CLASSES = [
'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra'
]
REL_CLASSES = [
'__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
'looking at', 'lying on', 'made of', 'mounted on', 'next to', 'of', 'on', 'on back of', 'over',
'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with'
]
# Argument parser used to read parameters from the command line
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--dataset', default='vg')
parser.add_argument('--img_path', type=str, default='demo/vg1.jpg', help="Path of the test image")
parser.add_argument('--backbone', default='resnet50', type=str, help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true', help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int, help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int, help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_entities', default=100, type=int, help="Number of query slots")
parser.add_argument('--num_triplets', default=200, type=int, help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', help="Disables auxiliary decoding losses (loss at each layer)")
parser.add_argument('--device', default='cuda', help='device to use for training / testing')
parser.add_argument('--resume', default='ckpt/checkpoint0149_oi.pth', help='resume from checkpoint')
parser.add_argument('--set_cost_class', default=1, type=float, help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float, help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float, help="giou box coefficient in the matching cost")
parser.add_argument('--set_iou_threshold', default=0.7, type=float, help="giou box coefficient in the matching cost")
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--rel_loss_coef', default=1, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class")
parser.add_argument('--return_interm_layers', action='store_true', help="Return the fpn if there is the tag")
return parser
# Build the model and load the pretrained weights
def create_model(args):
model, _, _ = build_model(args)
ckpt = torch.load(args.resume)
model.load_state_dict(ckpt['model'])
    model.eval()  # set the model to evaluation mode
return model
# Convert bounding boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2)
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
# Rescale bounding boxes from relative coordinates to the actual image size
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
# Custom delegate that renders an image together with text in a QComboBox item
class ImageDelegate(QStyledItemDelegate):
def paint(self, painter, option, index):
data = index.data(Qt.DecorationRole)
text = index.data(Qt.DisplayRole)
if isinstance(data, QIcon):
icon = data
rect = option.rect
painter.save()
            # draw the icon
icon.paint(painter, QRect(rect.left(), rect.top(), 80, 80))
            # set the font
font = QFont("Times New Roman", 14, QFont.Bold)
painter.setFont(font)
            # draw the text
painter.drawText(QRect(rect.left() + 90, rect.top(), rect.width() - 90, rect.height()), Qt.AlignVCenter, text)
painter.restore()
else:
super().paint(painter, option, index)
def sizeHint(self, option, index):
size = super().sizeHint(option, index)
        size.setHeight(80)  # adjust the item height
return size
# Main window class that handles the user interface and its events
class MainWindow(QMainWindow):
def __init__(self, args, recognizer):
super().__init__()
self.args = args
self.recognizer = recognizer
self.img_path = None
self.detected_classes = []
self.setWindowTitle("Visual Relationship Detection")
self.setGeometry(100, 100, 1200, 900)
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
layout = QVBoxLayout()
        # label showing the prompt message or the uploaded image
self.label = QLabel("Upload an image to analyze relationships", self)
font = QFont()
font.setPointSize(20)
font.setFamily("Times New Roman")
font.setBold(True)
self.label.setFont(font)
layout.addWidget(self.label)
        # upload button that triggers the image-upload handler
self.upload_btn = QPushButton("Upload Image", self)
font = QFont()
font.setPointSize(12)
font.setBold(True)
font.setFamily("Times New Roman")
self.upload_btn.setFont(font)
self.upload_btn.clicked.connect(self.upload_image)
layout.addWidget(self.upload_btn)
        # first subject drop-down; the user picks an object class of interest
self.item_selector1 = QComboBox(self)
font = QFont()
font.setPointSize(12)
font.setBold(True)
font.setFamily("Times New Roman")
self.item_selector1.setFont(font)
self.item_selector1.addItem("Choose the first subject")
self.item_selector1.setStyleSheet("QComboBox { padding: 1px; text-align: center; }"
"QComboBox QAbstractItemView { text-align: center; }")
layout.addWidget(self.item_selector1)
        # second subject drop-down; the user picks an object class of interest
self.item_selector2 = QComboBox(self)
font = QFont()
font.setPointSize(12)
font.setBold(True)
font.setFamily("Times New Roman")
self.item_selector2.setFont(font)
self.item_selector2.addItem("Choose the second subject")
self.item_selector2.setStyleSheet("QComboBox { padding: 1px; text-align: center; }"
"QComboBox QAbstractItemView { text-align: center; }")
layout.addWidget(self.item_selector2)
        # install the custom delegate
delegate = ImageDelegate(self.item_selector1)
self.item_selector1.setItemDelegate(delegate)
self.item_selector2.setItemDelegate(delegate)
        # analyze button that triggers the image analysis
self.analyze_btn = QPushButton("Analyze", self)
font = QFont()
font.setPointSize(12)
font.setBold(True)
font.setFamily("Times New Roman")
self.analyze_btn.setFont(font)
self.analyze_btn.clicked.connect(self.analyze_image)
layout.addWidget(self.analyze_btn)
        # read-only text box that shows the analysis results
self.results_text = QTextEdit(self)
self.results_text.setReadOnly(True)
font = QFont()
font.setPointSize(16)
font.setFamily("Times New Roman")
self.results_text.setFont(font)
layout.addWidget(self.results_text)
        # label that shows the speech-recognition status
self.recognition_status = QLabel("", self)
font = QFont()
font.setPointSize(14)
font.setFamily("Times New Roman")
font.setBold(True)
self.recognition_status.setFont(font)
self.recognition_status.setStyleSheet("QLabel { background-color : lightgreen; color : black; }")
self.recognition_status.setAlignment(Qt.AlignCenter)
self.recognition_status.setVisible(False)
layout.addWidget(self.recognition_status)
self.central_widget.setLayout(layout)
self.engine = pyttsx3.init()
self.last_text = ""
    # Upload an image and show it in the label
def upload_image(self):
self.img_path, _ = QFileDialog.getOpenFileName(self, "Open Image", "", "Image files (*.jpg *.jpeg *.png)")
if self.img_path:
pixmap = QPixmap(self.img_path)
scaled_pixmap = pixmap.scaledToWidth(800, Qt.SmoothTransformation)
self.label.setPixmap(scaled_pixmap)
self.label.setAlignment(Qt.AlignCenter)
self.label.setFixedSize(scaled_pixmap.size())
            # run object detection right after the image is uploaded
self.detect_objects()
    # Detect objects and update the drop-down menus
def detect_objects(self):
model = create_model(self.args)
im = Image.open(self.img_path)
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img = transform(im).unsqueeze(0)
outputs = model(img)
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
detected_indices = torch.where(probas_sub.max(1).values > 0.3)[0]
detected_classes = set(CLASSES[i] for i in probas_sub[detected_indices].argmax(-1).tolist())
self.update_item_selector(detected_classes)
    # Refresh the object classes shown in the drop-down menus
def update_item_selector(self, detected_classes):
self.item_selector1.clear()
self.item_selector1.addItem("Choose the first subject")
self.item_selector2.clear()
self.item_selector2.addItem("Choose the second subject")
for cls in detected_classes:
icon = QIcon(f"images/{cls}.jpg")
self.item_selector1.addItem(icon, cls)
self.item_selector2.addItem(icon, cls)
    # Draw bounding boxes on the image, using different colors for different objects
def draw_boxes_on_image(self, sub_bboxes, obj_bboxes, used_sub_bboxes, used_obj_bboxes, color, painter):
for sub_bbox, obj_bbox in zip(sub_bboxes, obj_bboxes):
obj_bbox_tuple = tuple(obj_bbox.tolist())
if obj_bbox_tuple not in used_obj_bboxes:
obj_color = QColor(color)
obj_color.setAlphaF(0.5)
painter.setPen(QPen(obj_color, 3))
obj_x1, obj_y1, obj_x2, obj_y2 = obj_bbox_tuple
painter.drawRect(obj_x1, obj_y1, obj_x2 - obj_x1, obj_y2 - obj_y1)
used_obj_bboxes.add(obj_bbox_tuple)
sub_bbox_tuple = tuple(sub_bbox.tolist())
if sub_bbox_tuple not in used_sub_bboxes:
sub_color = QColor(Qt.red)
sub_color.setAlphaF(0.5)
painter.setPen(QPen(sub_color, 3))
sub_x1, sub_y1, sub_x2, sub_y2 = sub_bbox_tuple
painter.drawRect(sub_x1, sub_y1, sub_x2 - sub_x1, sub_y2 - sub_y1)
used_sub_bboxes.add(sub_bbox_tuple)
    # Convert text to speech and play it
def text_to_speech(self, text):
self.engine.setProperty('rate', 200)
self.engine.setProperty('volume', 1)
self.engine.say(text)
self.engine.runAndWait()
tts = gTTS(text=text, lang='en')
tts.save("output.mp3")
os.system("mpg321 output.mp3")
threading.Thread(target=self.speech_recognition).start()
    # Replay the last audio
def repeat_last_audio(self):
threading.Thread(target=self.text_to_speech, args=(self.last_text,)).start()
    # Close the window
def close_window(self):
self.close()
    # Speech recognition: listen for the user's commands
def speech_recognition(self):
with sr.Microphone() as source:
self.recognition_status.setText("Listening for 'repeat or Thank you' command...")
self.recognition_status.setVisible(True)
            audio = self.recognizer.listen(source, timeout=10)  # listen for up to 10 seconds
try:
text = self.recognizer.recognize_google(audio)
self.recognition_status.setText(f"Recognized: {text}")
if "repeat" in text.lower():
self.show_recognition_status("Recognized: repeat")
self.repeat_last_audio()
elif "thank you" in text.lower():
self.show_recognition_status("Recognized: thank you and goodbye good luck")
self.close_window()
except sr.UnknownValueError:
self.recognition_status.setText("Could not understand audio")
print("Could not understand audio")
except sr.RequestError as e:
self.recognition_status.setText(f"Could not request results; {e}")
print(f"Could not request results; {e}")
finally:
threading.Timer(3, lambda: self.recognition_status.setVisible(False)).start()
    # Show a speech-recognition status message
def show_recognition_status(self, message):
self.recognition_status.setText(message)
self.recognition_status.setVisible(True)
threading.Timer(3, lambda: self.recognition_status.setVisible(False)).start()
    # Analyze the image: detect objects and the relationships between them
def analyze_image(self):
selected_sub1 = self.item_selector1.currentText()
selected_sub2 = self.item_selector2.currentText()
if self.img_path:
result_text = self.detect_relationships(self.img_path, selected_sub1, selected_sub2)
threading.Thread(target=self.text_to_speech, args=(result_text,)).start()
else:
print("Please upload an image first.")
    # Detect relationships in the image, draw bounding boxes, and return the result text
def detect_relationships(self, img_path, selected_sub1, selected_sub2):
model = create_model(self.args)
im = Image.open(img_path)
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img = transform(im).unsqueeze(0)
outputs = model(img)
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
sub_bboxes = outputs['sub_boxes'][0]
obj_bboxes = outputs['obj_boxes'][0]
sub_bboxes_scaled = rescale_bboxes(sub_bboxes, im.size)
obj_bboxes_scaled = rescale_bboxes(obj_bboxes, im.size)
high_confidence_indices = torch.where(probas.max(1).values > 0.6)[0][:10]
relationships = []
for idx in high_confidence_indices:
sub_class = CLASSES[probas_sub[idx].argmax()]
obj_class = CLASSES[probas_obj[idx].argmax()]
rel_class = REL_CLASSES[probas[idx].argmax()]
relationships.append((sub_class, rel_class, obj_class, sub_bboxes_scaled[idx], obj_bboxes_scaled[idx]))
used_sub_bboxes = set()
used_obj_bboxes = set()
self.results_text.clear()
pixmap = QPixmap(self.img_path)
painter = QPainter(pixmap)
result_text = ""
        result_count = 1  # result counter
if selected_sub1 == "Choose the first subject" and selected_sub2 == "Choose the second subject":
result_text += "Relationships involving all objects:\n"
for sub_class, rel_class, obj_class, sub_bbox, obj_bbox in relationships:
self.draw_boxes_on_image([sub_bbox], [obj_bbox], used_sub_bboxes, used_obj_bboxes, Qt.yellow, painter)
result_text += f"{result_count}. In this picture, I look the {sub_class} in the front, and the {sub_class} {rel_class} {obj_class}\n"
result_count += 1
else:
if selected_sub1 != "Choose the first subject":
result_text += f"Relationships involving '{selected_sub1}':\n"
for sub_class, rel_class, obj_class, sub_bbox, obj_bbox in relationships:
if sub_class == selected_sub1:
self.draw_boxes_on_image([sub_bbox], [obj_bbox], used_sub_bboxes, used_obj_bboxes, Qt.green, painter)
result_text += f"{result_count}. In this picture, I look the {sub_class} in the front, and the {sub_class} {rel_class} {obj_class}\n"
result_count += 1
if selected_sub2 != "Choose the second subject":
result_text += f"Relationships involving '{selected_sub2}':\n"
for sub_class, rel_class, obj_class, sub_bbox, obj_bbox in relationships:
if sub_class == selected_sub2:
self.draw_boxes_on_image([sub_bbox], [obj_bbox], used_sub_bboxes, used_obj_bboxes, Qt.blue, painter)
result_text += f"{result_count}. In this picture, I look the {sub_class} in the front, and the {sub_class} {rel_class} {obj_class}\n"
result_count += 1
painter.end()
self.label.setPixmap(pixmap)
self.results_text.setText(result_text)
self.last_text = result_text
return result_text
    # Handle key presses: replay the last audio
def keyPressEvent(self, event):
if event.key() == Qt.Key_R:
self.repeat_last_audio()
# Entry point: set up the argument parser, the speech recognizer, and the application
if __name__ == '__main__':
parser = argparse.ArgumentParser('RelTR inference', parents=[get_args_parser()])
args = parser.parse_args()
recognizer = sr.Recognizer()
app = QApplication(sys.argv)
window = MainWindow(args, recognizer)
window.show()
sys.exit(app.exec_())
```
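
The coordinate helpers `box_cxcywh_to_xyxy` and `rescale_bboxes` used above are easy to sanity-check on their own. The following is a minimal standalone sketch (it only needs `torch`; the sample box and image size are made up for illustration) that confirms the center-format to corner-format conversion and the rescaling from relative coordinates to pixels:

```python=
import torch

def box_cxcywh_to_xyxy(x):
    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # scale relative [0, 1] boxes to absolute pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    return b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)

# a single box centered at (0.5, 0.5) covering half of an 800x600 image
boxes = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
print(box_cxcywh_to_xyxy(boxes))          # tensor([[0.2500, 0.2500, 0.7500, 0.7500]])
print(rescale_bboxes(boxes, (800, 600)))  # tensor([[200., 150., 600., 450.]])
```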
## Direct Speech Recognition
```python=
import argparse
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from models import build_model  # import the model-building function from the custom module
from gtts import gTTS  # text-to-speech package
import pyttsx3  # text-to-speech package
import threading
import speech_recognition as sr  # speech-recognition package
def get_args_parser():
    # define the command-line arguments
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--dataset', default='vg')
    parser.add_argument('--img_path', type=str, default='demo/vg1.jpg',
                        help="Path of the test image")  # path to the test image
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")  # convolutional backbone to use
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
parser.add_argument('--enc_layers', default=6, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_entities', default=100, type=int,
help="Number of query slots")
parser.add_argument('--num_triplets', default=200, type=int,
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
help="Disables auxiliary decoding losses (loss at each layer)")
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--resume', default='ckpt/checkpoint0149_oi.pth', help='resume from checkpoint')
parser.add_argument('--set_cost_class', default=1, type=float,
help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
help="giou box coefficient in the matching cost")
parser.add_argument('--set_iou_threshold', default=0.7, type=float,
help="giou box coefficient in the matching cost")
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--rel_loss_coef', default=1, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
help="Relative classification weight of the no-object class")
parser.add_argument('--return_interm_layers', action='store_true',
help="Return the fpn if there is the tag")
return parser
def speech_engine(text, recognizer):
    # text-to-speech
    engine = pyttsx3.init()  # initialize the text-to-speech engine
    voices = engine.getProperty('voices')
    engine.setProperty('voice', voices[10].id)  # choose a voice (index depends on the voices installed)
    engine.setProperty('rate', 180)  # set the speaking rate
    engine.setProperty('volume', 1)  # set the volume
    engine.say(text)  # queue the text for speech
    engine.runAndWait()  # wait until playback finishes
    # start speech recognition
    recognized_text = handle_audio(recognizer)  # handle the spoken input with the recognizer
    if recognized_text.lower() == "repeat":  # if the recognized result is "repeat", play the text again
        threading.Thread(target=speech_engine, args=(text, recognizer)).start()  # replay the text in a new thread
def handle_audio(recognizer):
    # handle spoken input
    with sr.Microphone() as source:  # use the microphone as the audio source
        print("Please speak...")  # prompt the user to start speaking
        try:
            audio_data = recognizer.listen(source)  # listen on the microphone
            text = recognizer.recognize_google(audio_data, language='en-US')  # recognize the speech with the Google Speech Recognition API
            print("You said:", text)  # show the recognized text
            return text
        except sr.UnknownValueError:
            print("Sorry, I could not understand your audio input.")  # the audio could not be recognized
            return ""
        except sr.RequestError as e:
            print("Could not request results from Google Speech Recognition API; {0}".format(e))  # the request failed
return ""
def main(args):
ordinal_numbers = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Sixth', 'Seventh', 'Eighth', 'Ninth', 'Tenth']
transform = T.Compose([
        T.Resize(800),  # resize the shorter side of the image to 800 pixels
        T.ToTensor(),  # convert the image to a tensor
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize the image
])
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
CLASSES = [ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
        'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
'from', 'growing on', 'hanging from','has', 'holding', 'in', 'in front of', 'laying on',
'looking at', 'lying on', 'made of', 'mounted on', 'next to', 'of', 'on', 'on back of', 'over',
'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
    model, _, _ = build_model(args)  # build the model
    ckpt = torch.load(args.resume)
    model.load_state_dict(ckpt['model'])  # load the model weights
    model.eval()  # set the model to evaluation mode
    img_path = args.img_path
    im = Image.open(img_path)  # open the image file
    img = transform(im).unsqueeze(0)  # convert the image into the format the model expects
    outputs = model(img)  # run inference
    # get the class probabilities from the outputs
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
    # keep only the predictions with sufficiently high confidence
keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3,
probas_obj.max(-1).values > 0.3))
    # rescale the bounding boxes to the image size
sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)
topk = 5
keep_queries = torch.nonzero(keep, as_tuple=True)[0]
indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk]
keep_queries = keep_queries[indices]
conv_features, dec_attn_weights_sub, dec_attn_weights_obj = [], [], []
hooks = [
model.backbone[-2].register_forward_hook(
lambda self, input, output: conv_features.append(output)
),
model.transformer.decoder.layers[-1].cross_attn_sub.register_forward_hook(
lambda self, input, output: dec_attn_weights_sub.append(output[1])
),
model.transformer.decoder.layers[-1].cross_attn_obj.register_forward_hook(
lambda self, input, output: dec_attn_weights_obj.append(output[1])
)
]
with torch.no_grad():
outputs = model(img)
for hook in hooks:
hook.remove()
conv_features = conv_features[0]
dec_attn_weights_sub = dec_attn_weights_sub[0]
dec_attn_weights_obj = dec_attn_weights_obj[0]
h, w = conv_features['0'].tensors.shape[-2:]
im_w, im_h = im.size
fig, axs = plt.subplots(ncols=len(indices), nrows=3, figsize=(22, 7))
for idx, ax_i, (sxmin, symin, sxmax, symax), (oxmin, oymin, oxmax, oymax) in \
zip(keep_queries, axs.T, sub_bboxes_scaled[indices], obj_bboxes_scaled[indices]):
ax = ax_i[0]
ax.imshow(dec_attn_weights_sub[0, idx].view(h, w))
ax.axis('off')
ax.set_title(f'query id: {idx.item()}')
ax = ax_i[1]
ax.imshow(dec_attn_weights_obj[0, idx].view(h, w))
ax.axis('off')
ax = ax_i[2]
ax.imshow(im)
ax.add_patch(plt.Rectangle((sxmin, symin), sxmax - sxmin, symax - symin,
fill=False, color='blue', linewidth=2.5))
ax.add_patch(plt.Rectangle((oxmin, oymin), oxmax - oxmin, oymax - oymin,
fill=False, color='orange', linewidth=2.5))
ax.axis('off')
ax.set_title(CLASSES[probas_sub[idx].argmax()]+' '+REL_CLASSES[probas[idx].argmax()]+' '+CLASSES[probas_obj[idx].argmax()], fontsize=10)
detected_relations_text = '. '.join([
f"{ordinal_numbers[idx]}: In this picture, I look the {CLASSES[probas_sub[query_idx].argmax()]} in the front, and the {CLASSES[probas_sub[query_idx].argmax()]} {REL_CLASSES[probas[query_idx].argmax()]} {CLASSES[probas_obj[query_idx].argmax()]}"
for idx, query_idx in enumerate(keep_queries[:len(ordinal_numbers)])
])
if detected_relations_text:
        tts = gTTS(detected_relations_text, lang='en')  # convert the detected-relations text to speech
        tts.save("detected_relations.mp3")  # save the speech as an MP3 file
        print("Detected relationships have been converted to speech and saved as 'detected_relations.mp3'.")
        threading.Thread(target=speech_engine, args=(detected_relations_text, recognizer)).start()  # play the speech in a new thread
else:
print("No relationships detected.")
fig.tight_layout()
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser('RelTR inference', parents=[get_args_parser()])
args = parser.parse_args()
    recognizer = sr.Recognizer()  # initialize the speech recognizer
main(args)
```
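
The filtering and ranking step in this script is what decides which triplets get spoken. Below is a minimal sketch with random tensors standing in for the softmaxed `rel_logits`, `sub_logits`, and `obj_logits` (the 0.3 threshold and top-k of 5 follow the code above; all the data is dummy): it keeps only the queries whose relation, subject, and object confidences all exceed 0.3 and ranks the survivors by the product of the three scores.

```python=
import torch

torch.manual_seed(0)
num_queries = 8

# stand-ins for outputs['rel_logits'], ['sub_logits'], ['obj_logits'] after softmax,
# with the background / no-object column already dropped
probas = torch.rand(num_queries, 50)       # relation scores
probas_sub = torch.rand(num_queries, 150)  # subject scores
probas_obj = torch.rand(num_queries, 150)  # object scores

# keep a query only if relation, subject, and object are all confident enough
keep = torch.logical_and(probas.max(-1).values > 0.3,
                         torch.logical_and(probas_sub.max(-1).values > 0.3,
                                           probas_obj.max(-1).values > 0.3))
keep_queries = torch.nonzero(keep, as_tuple=True)[0]

# rank the surviving queries by the product of the three confidences, keep the top 5
topk = 5
scores = (probas[keep_queries].max(-1)[0]
          * probas_sub[keep_queries].max(-1)[0]
          * probas_obj[keep_queries].max(-1)[0])
keep_queries = keep_queries[torch.argsort(-scores)[:topk]]
print("queries to report:", keep_queries.tolist())
```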
## main.py
```python=
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Institute of Information Processing, Leibniz University Hannover.
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
import datasets
import util.misc as utils
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--batch_size', default=2, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr_drop', default=100, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')
# Model parameters
parser.add_argument('--frozen_weights', type=str, default=None,
help="Path to the pretrained model. If set, only the mask head will be trained")
# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str,
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_entities', default=100, type=int,
help="Number of query slots")
parser.add_argument('--num_triplets', default=200, type=int,
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')
# Loss
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
help="Disables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
help="giou box coefficient in the matching cost")
parser.add_argument('--set_iou_threshold', default=0.7, type=float,
help="giou box coefficient in the matching cost")
# * Loss coefficients
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--rel_loss_coef', default=1, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
help="Relative classification weight of the no-object class")
# dataset parameters
parser.add_argument('--dataset', default='vg')
parser.add_argument('--ann_path', default='./data/vg/', type=str)
parser.add_argument('--img_folder', default='/home/cong/Dokumente/tmp/data/visualgenome/images/', type=str)
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--num_workers', default=2, type=int)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--return_interm_layers', action='store_true',
help="Return the fpn if there is the tag")
return parser
def main(args):
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
param_dicts = [
{"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
dataset_train = build_dataset(image_set='train', args=args)
dataset_val = build_dataset(image_set='val', args=args)
if args.distributed:
sampler_train = DistributedSampler(dataset_train)
sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True)
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn, num_workers=args.num_workers)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
base_ds = get_coco_api_from_dataset(dataset_val)
if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location='cpu')
model_without_ddp.detr.load_state_dict(checkpoint['model'])
output_dir = Path(args.output_dir)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
# del checkpoint['optimizer']
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.eval:
print('It is the {}th checkpoint'.format(checkpoint['epoch']))
test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args)
if args.output_dir:
utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth'] # anti-crash
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0:
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args,
}, checkpoint_path)
test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(output_dir / 'eval').mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ['latest.pth']
if epoch % 50 == 0:
filenames.append(f'{epoch:03}.pth')
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
output_dir / "eval" / name)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser('RelTR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
```
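
main.py saves and resumes training state through a plain checkpoint dictionary with the keys `model`, `optimizer`, `lr_scheduler`, and `epoch`. The sketch below reproduces that round trip with a tiny stand-in `nn.Linear` model rather than the real RelTR model, just to illustrate the format `--resume` expects; the file name and epoch number are made up.

```python=
import torch
from torch import nn

# a tiny stand-in model; main.py would use the RelTR model built by build_model(args)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# save in the same layout main.py writes with utils.save_on_master(...)
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': 7,
}, 'checkpoint.pth')

# resume, mirroring the --resume branch of main(): weights always,
# optimizer / scheduler / epoch only when not running with --eval
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
start_epoch = checkpoint['epoch'] + 1
print('resuming from epoch', start_epoch)
```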
## Press 'r' to Replay the Audio
```python=
import argparse  # for parsing command-line arguments
from PIL import Image  # for image handling
import matplotlib.pyplot as plt  # for plotting
import torch  # PyTorch deep-learning library
import torchvision.transforms as T  # PyTorch image-transform module
from models import build_model  # custom model-building function
from gtts import gTTS  # for text-to-speech
import pyttsx3  # for speech synthesis
import threading  # for multithreading
from pynput import keyboard  # for listening to keyboard events
# Define the function that parses command-line arguments
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)  # initial learning rate of the backbone
    parser.add_argument('--dataset', default='vg')  # dataset
    parser.add_argument('--img_path', type=str, default='demo/vg1.jpg',  # path to the test image
                        help="Path of the test image")
    parser.add_argument('--backbone', default='resnet50', type=str,  # convolutional backbone to use
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',  # whether to use dilated convolution
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),  # type of positional embedding
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--enc_layers', default=6, type=int,  # number of encoder layers in the transformer
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,  # number of decoder layers in the transformer
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,  # feed-forward dimension in the transformer blocks
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,  # embedding dimension
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,  # dropout rate
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,  # number of attention heads
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_entities', default=100, type=int,  # number of entity query slots
                        help="Number of query slots")
    parser.add_argument('--num_triplets', default=200, type=int,  # number of triplet query slots
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')  # apply layer normalization before each sub-layer (pre-norm)
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',  # disable auxiliary decoding losses
                        help="Disables auxiliary decoding losses (loss at each layer)")
    parser.add_argument('--device', default='cuda',  # device (GPU or CPU)
                        help='device to use for training / testing')
    parser.add_argument('--resume', default='ckpt/checkpoint0149_oi.pth',  # checkpoint path to resume from
                        help='resume from checkpoint')
    parser.add_argument('--set_cost_class', default=1, type=float,  # class weight in the matching cost
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,  # L1 box weight in the matching cost
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,  # giou box weight in the matching cost
                        help="giou box coefficient in the matching cost")
    parser.add_argument('--set_iou_threshold', default=0.7, type=float,  # IoU threshold
                        help="giou box coefficient in the matching cost")
    parser.add_argument('--bbox_loss_coef', default=5, type=float)  # box loss weight
    parser.add_argument('--giou_loss_coef', default=2, type=float)  # giou loss weight
    parser.add_argument('--rel_loss_coef', default=1, type=float)  # relation loss weight
    parser.add_argument('--eos_coef', default=0.1, type=float,  # relative classification weight of the no-object class
                        help="Relative classification weight of the no-object class")
    parser.add_argument('--return_interm_layers', action='store_true',  # whether to return the FPN features
                        help="Return the fpn if there is the tag")
return parser
# Define the speech-synthesis function
def speech_engine(text):
    engine = pyttsx3.init()  # initialize the speech-synthesis engine
    voices = engine.getProperty('voices')  # get the list of available voices
    engine.setProperty('voice', voices[10].id)  # pick a voice (index depends on the voices installed)
    engine.setProperty('rate', 180)  # speaking rate (words per minute)
    engine.setProperty('volume', 1)  # volume (0.0 to 1.0)
    engine.say(text)  # queue the text for synthesis
    engine.runAndWait()  # wait until all queued speech has been played
# Define the main function
def main(args):
ordinal_numbers = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Sixth', 'Seventh', 'Eighth', 'Ninth', 'Tenth']
    # define the image preprocessing transforms
    transform = T.Compose([
        T.Resize(800),  # resize the shorter side of the image to 800 pixels
        T.ToTensor(),  # convert the image to a tensor
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize
    ])
    # Convert boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2)
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
    # Rescale the relative bounding boxes back to the original image size
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
    # Define the class and relation lists
CLASSES = [ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
        'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
'from', 'growing on', 'hanging from','has', 'holding', 'in', 'in front of', 'laying on',
'looking at', 'lying on', 'made of', 'mounted on', 'next to', 'of', 'on', 'on back of', 'over',
'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
    # Build the model and load the weights
model, _, _ = build_model(args)
ckpt = torch.load(args.resume)
model.load_state_dict(ckpt['model'])
model.eval()
    # Load and preprocess the image
img_path = args.img_path
im = Image.open(img_path)
img = transform(im).unsqueeze(0)
    # Run model inference
outputs = model(img)
    # Extract the predictions
probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3,
probas_obj.max(-1).values > 0.3))
sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)
topk = 5
keep_queries = torch.nonzero(keep, as_tuple=True)[0]
indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk]
keep_queries = keep_queries[indices]
conv_features, dec_attn_weights_sub, dec_attn_weights_obj = [], [], []
    # Register hooks to capture intermediate features and attention weights from the model
hooks = [
model.backbone[-2].register_forward_hook(
lambda self, input, output: conv_features.append(output)
),
model.transformer.decoder.layers[-1].cross_attn_sub.register_forward_hook(
lambda self, input, output: dec_attn_weights_sub.append(output[1])
),
model.transformer.decoder.layers[-1].cross_attn_obj.register_forward_hook(
lambda self, input, output: dec_attn_weights_obj.append(output[1])
)
]
with torch.no_grad():
outputs = model(img)
for hook in hooks:
hook.remove()
conv_features = conv_features[0]
dec_attn_weights_sub = dec_attn_weights_sub[0]
dec_attn_weights_obj = dec_attn_weights_obj[0]
h, w = conv_features['0'].tensors.shape[-2:]
im_w, im_h = im.size
    # Plot and display the predictions
fig, axs = plt.subplots(ncols=len(indices), nrows=3, figsize=(22, 7))
for idx, ax_i, (sxmin, symin, sxmax, symax), (oxmin, oymin, oxmax, oymax) in \
zip(keep_queries, axs.T, sub_bboxes_scaled[indices], obj_bboxes_scaled[indices]):
ax = ax_i[0]
ax.imshow(dec_attn_weights_sub[0, idx].view(h, w))
ax.axis('off')
ax.set_title(f'query id: {idx.item()}')
ax = ax_i[1]
ax.imshow(dec_attn_weights_obj[0, idx].view(h, w))
ax.axis('off')
ax = ax_i[2]
ax.imshow(im)
ax.add_patch(plt.Rectangle((sxmin, symin), sxmax - sxmin, symax - symin,
fill=False, color='blue', linewidth=2.5))
ax.add_patch(plt.Rectangle((oxmin, oymin), oxmax - oxmin, oymax - oymin,
fill=False, color='orange', linewidth=2.5))
ax.axis('off')
ax.set_title(CLASSES[probas_sub[idx].argmax()]+' '+REL_CLASSES[probas[idx].argmax()]+' '+CLASSES[probas_obj[idx].argmax()], fontsize=10)
    # Build the text describing the detected relations
detected_relations_text = '. '.join([
f"{ordinal_numbers[idx]}: In this picture, I look the {CLASSES[probas_sub[query_idx].argmax()]} in the front, and the {CLASSES[probas_sub[query_idx].argmax()]} {REL_CLASSES[probas[query_idx].argmax()]} {CLASSES[probas_obj[query_idx].argmax()]}"
for idx, query_idx in enumerate(keep_queries[:len(ordinal_numbers)])
])
    # Convert the detected-relations text to speech and save it
    if detected_relations_text:
        tts = gTTS(detected_relations_text, lang='en')  # convert the text to speech with gTTS
        tts.save("detected_relations.mp3")  # save the speech as an MP3 file
        print("Detected relationships have been converted to speech and saved as 'detected_relations.mp3'.")
        threading.Thread(target=speech_engine, args=(detected_relations_text,)).start()  # play the speech in a new thread
    else:
        print("No relationships detected.")  # no relationships were detected
fig.tight_layout()
    # Define the key-press handler
    def on_press(key):
        if key == keyboard.Key.esc:
            plt.close(fig)  # close the figure window
            return False
        if getattr(key, 'char', None) == 'r':  # special keys have no .char attribute, so guard the access
            threading.Thread(target=speech_engine, args=(detected_relations_text,)).start()  # replay the speech when 'r' is pressed
    listener = keyboard.Listener(on_press=on_press)  # listen for keyboard events
    listener.start()  # start the listener
    plt.show()  # show the figure
    listener.join()  # wait for the listener to finish
if __name__ == '__main__':
    parser = argparse.ArgumentParser('RelTR inference', parents=[get_args_parser()])  # build the command-line argument parser
    args = parser.parse_args()  # parse the command-line arguments
    main(args)  # run the main function
```
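
The pynput part of this script is independent of the model, so the replay-on-`r` pattern can be tried on its own. This is a minimal sketch with a placeholder `replay` function and a dummy `last_text` string instead of `speech_engine` and the detected-relations text; it keeps the guarded `key.char` access, since special keys such as Shift have no `char` attribute.

```python=
import threading
from pynput import keyboard

def replay(text):
    # placeholder for speech_engine(text); here we just print
    print("replaying:", text)

last_text = "First: In this picture, ..."

def on_press(key):
    if key == keyboard.Key.esc:
        return False  # stop the listener
    # special keys (Shift, Ctrl, ...) have no .char attribute, so guard the access
    if getattr(key, 'char', None) == 'r':
        threading.Thread(target=replay, args=(last_text,)).start()

# run the listener until Esc is pressed
with keyboard.Listener(on_press=on_press) as listener:
    listener.join()
```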