web-dev-qa-db-fra.com

Comment visualiser un filet dans Pytorch?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
Epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

Je veux visualiser resnet à partir des modèles Pytorch. Comment puis-je le faire? J'ai essayé d'utiliser torchviz mais cela donne une erreur:

'ResNet' object has no attribute 'grad_fn'
23
raaj

make_dot attend une variable (c'est-à-dire un tenseur avec grad_fn), pas le modèle lui-même.
essayer:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module
0
Shai

Vous pouvez jeter un œil à PyTorchViz ( https://github.com/szagoruyko/pytorchviz ), "Un petit paquet pour créer des visualisations des graphiques et des traces d'exécution PyTorch."

Example PyTorchViz visualization

0
David J.