# PyTorch NaN問題排除
某些loss function,例如ciou或giou在計算梯度時偶然會產生NaN,為了排除問題,可以採用以下方法:
1. torch.autograd.set_detect_anomaly(True):會在backward時檢查NaN或其他異常的存在,也可以寫成:
```python=3.6
with torch.autograd.detect_anomaly():
loss.backward()
```
2. register_hook、register_backward_hook:可以修改梯度。
3. torch.nan_to_num:把NaN替換成指定數字。
4. torch.clamp:通常NaN出現在sqrt、exp之類地方,因值太大或太小而溢位,或是atan、atan2出現0值,可以用clamp把值限制在一個區間內。以下是一些解法:
```python=3.6
#原本造成NaN的計算
alpha = torch.atan2(xyz[..., 1], xyz[..., 0])
```
改成
```python=3.6
epsilon = torch.ones_like(xyz[..., 0]) * 1e-5
xyz_tmp = torch.where(xyz[..., 0]==0, xyz[..., 0], xyz[..., 0]+epsilon)
alpha = torch.atan2(xyz[..., 1], xyz_tmp)
```
除此之外,可能是因為學習率設太大,ciou或giou就算做了再多防護措施,仍會很神奇的會因學習率太大而出事。