Notice
Recent Posts
Recent Comments
Link
거의 알고리즘 일기장
network를 Sequential을 이용해서 간단하게 만들기 & 모델 저장,불러오기 본문
In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.autograd import Variable
import visdom
import torch.optim as optim
import torchvision.transforms as transforms
In [7]:
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
trainset = datasets.CIFAR10(root='./data',
train = True,
download = True,
transform = transform)
testset = datasets.CIFAR10(root='./data',
train = False,
download = True,
transform = transform)
In [13]:
class Net(nn.Module):
def __init__(self):
#가중치 초기화
super(Net, self).__init__()
#sequential을 사용하면 편하게 모델 만들기 가능
self.layer1 = nn.Sequential(
nn.Conv2d(3, 32, 5),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, 3),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3),
nn.MaxPool2d(2)
)
self.layer2_1 = nn.Sequential(
nn.Conv2d(256, 512, 7, 1, 2),
nn.Conv2d(512, 64, 1),
nn.MaxPool2d(2)
)
self.layer2_2 = nn.Sequential(
nn.Conv2d(256, 512, 5, 1, 1),
nn.Conv2d(512, 64, 1),
nn.MaxPool2d(2)
)
self.layer2_3 = nn.Sequential(
nn.Conv2d(256, 512, 3),
nn.Conv2d(512, 64, 1),
nn.MaxPool2d(2)
)
self.fc = nn.Sequential(
nn.Linear(3* 64* 4 * 4, 1024),
nn.ReLU(True),
nn.Linear(1024, 10)
)
def forward(self, x):
print(x.data.shape)
x = self.layer1(x)
x1 = self.layer2_1(x)
x2 = self.layer2_2(x)
x3 = self.layer2_3(x)
#합쳐줌, B x C x H x W
# 0 1 2 3
#이건 채널을 기준으로 합침
x = torch.cat((x1, x2, x3), 1)
#B는 두고 나머지를 일자로 핌
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
In [14]:
#되는지 확인
a = torch.rand(1, 3, 32, 32)
print(a.shape)
a = Variable(a)
print(a.shape)
In [15]:
network = Net()
out = network(a)
print(out.shape)
In [16]:
print(out)
In [17]:
#모델 저장
torch.save(network.state_dict(), './cnn.pth')
In [18]:
#모델 불러오기
#모델 아키텍쳐 정의
model = Net()
#모델 아키텍쳐와 동일한 모델을 불러온다!!
model.load_state_dict(torch.load('./cnn.pth'))
Out[18]:
In [19]:
out = model(a)
print(out)
In [ ]:
반응형
'pytorch 사용법' 카테고리의 다른 글
visdom 사용법 (0) | 2020.09.22 |
---|---|
custom dataset사용방법 _ image의 경우 (0) | 2020.09.22 |
optim & criterion (0) | 2020.09.21 |
make NN (0) | 2020.09.21 |
Data Loader 사용법 (0) | 2020.09.21 |
Comments