最近想到
如果 pytorch 要顯示模型樹狀圖架構
類似於 keras 的 plot_model
要怎麼做呢?
找一找發現以下方法可以
我有用 resnet18 做一個範例給大家參考
主要是要先安裝 graphviz
https://graphviz.org/download/
直接下載到特定路徑就可以
然後在執行 python 時候特別指定就可以
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# 用 resnet18 來示範
class Resnet18Models(nn.Module):
def __init__(self, num_classes=2, pretrained=False):
super(Resnet18Models, self).__init__()
self.base_model = timm.create_model('resnet18', pretrained=pretrained).cuda()
# 取到倒數第二層 特徵層
self.base_model = nn.Sequential(*list(self.base_model.children())[:-1]).cuda()
self.bn1 = nn.BatchNorm1d(512).cuda()
self.relu = torch.relu
self.fc = nn.Linear(512, num_classes).cuda()
def forward(self, x):
# 資料進入 resnet18
x = self.base_model(x)
x = self.bn1(x)
x = self.relu(x)
x = self.fc(x)
# 分類架構
x = F.softmax(x, dim=1)
return x
# 載入 torchviz
from torchviz import make_dot
os.environ["PATH"] += os.pathsep + 'C:\graphviz-2.44.1-win32/Graphviz/bin/' # 安裝graphviz的路徑
# 新增模型
mainModel = Resnet18Models()
# 決定輸入大小
inputSize = (3, 128, 128)
# 產生輸入的示範資料
inputDatas = torch.zeros((2, inputSize[0], inputSize[1], inputSize[2]), requires_grad=False).cuda()
# 輸入模型
output = mainModel(inputDatas)
# make_dot 用輸出資料反推模型
modelImg = make_dot(output, params=dict(mainModel.named_parameters()), show_saved=True)
# 顯示架構
modelImg.view()
成功則顯示以下圖片
顯示出來的圖片是這樣
感覺跟keras 的 plot_model 出來的圖還是不太一樣![]()
感覺沒這麼易讀
參數部分感覺是依照每一層的計算方式去列出來的
連 fc weight 跟 bias 參數都有顯示
算是很詳細
但感覺過於複雜?
不過整體架構還是可以理解的
給大家參考囉


留言板
歡迎留下建議與分享!希望一起交流!感恩!