你已經成功建立與訓練一個 MIMO-UNet 模型來處理模糊圖片,但你觀察到推論出來的圖片依然模糊,代表 **模型的學習效果有限或尚未收斂**。這很常見,我幫你整理出一套「強化訓練效果」的策略,從**資料、模型、訓練流程**全面優化。 --- ## 🔧 一、資料相關調整 ### ✅ 1. 增加資料多樣性與數量 - 資料量太少會導致模型無法泛化,建議: - 收集更多類型的模糊圖片(不同模糊程度、方向、背景) - 使用模糊核自動產生模糊圖片 (`cv2.GaussianBlur`、motion blur) ### ✅ 2. 增加資料增強(Data Augmentation) 目前 Dataset 中僅做了 Resize 與 ToTensor,建議加上: ```python self.transform = T.Compose([ T.Resize((256, 256)), T.RandomHorizontalFlip(), T.RandomRotation(10), T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1), T.ToTensor() ]) ``` --- ## 🧠 二、模型架構強化(可選升級) ### ✅ 1. 加入 `BatchNorm2d`:提升穩定性 修改 `ResidualBlock`: ```python class ResidualBlock(nn.Module): def __init__(self, channels): super().__init__() self.block = nn.Sequential( nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels), nn.ReLU(), nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels) ) def forward(self, x): return x + self.block(x) ``` ### ✅ 2. 替換 Loss 函數為更感知的損失 目前用的是 `MSELoss`,但這對視覺品質不夠敏感。建議改用: ```python import torchvision.models as models class PerceptualLoss(nn.Module): def __init__(self): super().__init__() vgg = models.vgg16(pretrained=True).features[:16].eval() for param in vgg.parameters(): param.requires_grad = False self.vgg = vgg self.loss_fn = nn.MSELoss() def forward(self, pred, target): vgg_pred = self.vgg(pred) vgg_target = self.vgg(target) return self.loss_fn(vgg_pred, vgg_target) ``` 或者組合兩者: ```python loss_fn = lambda pred, target: mse_loss(pred, target) + 0.1 * perceptual_loss(pred, target) ``` --- ## 🚀 三、訓練流程強化建議 ### ✅ 1. 增加 Epoch 數 你的 30 epochs 可能不足,建議先訓練 100+ epochs(儲存最佳模型)。 ### ✅ 2. 使用學習率衰減策略 加上: ```python scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5) ``` 在每個 epoch 結尾加入: ```python scheduler.step(total_loss) ``` ### ✅ 3. 加入 Validation 集與評估指標(PSNR / SSIM) 這有助於觀察模型是否過擬合: ```python from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim def evaluate(model, loader, device): model.eval() psnr_total, ssim_total = 0, 0 with torch.no_grad(): for blur, sharp in loader: blur, sharp = blur.to(device), sharp.to(device) output = model(blur).clamp(0, 1) for pred, gt in zip(output, sharp): pred_img = pred.permute(1, 2, 0).cpu().numpy() gt_img = gt.permute(1, 2, 0).cpu().numpy() psnr_total += psnr(gt_img, pred_img, data_range=1.0) ssim_total += ssim(gt_img, pred_img, channel_axis=-1, data_range=1.0) print(f"[Eval] PSNR: {psnr_total/len(loader.dataset):.2f}, SSIM: {ssim_total/len(loader.dataset):.4f}") ``` --- ## ✅ 小結(調整重點整理) | 項目 | 建議 | |------|------| | 📁 資料 | 增加樣本數、多樣性、資料增強 | | 🧠 模型 | 加入 BatchNorm、可選 Perceptual Loss | | 🔧 損失函數 | MSE + Perceptual Loss 組合更穩定 | | 📊 評估 | 加入 PSNR/SSIM、Validation dataset | | 🔁 訓練 | Epoch 100+、使用學習率調整策略 | --- ## 🎁 Bonus:我可以幫你 - 加入完整的 **可視化曲線 + 評估指標** - 幫你整合成 GUI、CLI 或 Web API - 產生一份訓練報告(含訓練圖表與效果圖) 需要我幫你產生優化後的 `train.py` 或重新整理資料集架構嗎?你只要告訴我你目前的資料集規劃,我就能幫你改完整程式。