最近遇到 PyTorch 訓練出現 loss 為 nan

蠻奇怪的

第一輪模型訓練都沒問題

到第二輪就變成nan

檢查輸入的影像資料沒有nan

輸出的output 卻為 nan

 

torch.isnan(images).any() =  tensor(False, device='cuda:0')
outputs =  tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]], device='cuda:0', grad_fn=<AddmmBackward0>)

 

 

還有一個可能就是模型參數有 nan

設下檢查點

 

for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f"⚠️ {name} 包含 NaN")
    if torch.isinf(param).any():
        print(f"⚠️ {name} 包含 Inf(無窮大)")

發現還真有參數是 nan

 

 

⚠️ text_projection 包含 NaN

查看一下程式碼

 

 

self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

發現他初始使用 torch.empty 來新增參數

 

可能就是這個原因

改成 torch.randn 或 torch.zeros 試試

 

發現還真的問題就解決了...

但不知道為什麼這問題在我的PC沒問題

在DGX上面有問題

有時候這種問題還真不好找

給大家參考囉