거의 알고리즘 일기장

make NN 본문

pytorch 사용법

make NN

건우권 2020. 9. 21. 15:20
Untitled1
In [3]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import numpy as np

import torchvision 
import torchvision.transforms as transforms
In [4]:
# Preprocessing pipeline: ToTensor converts the PIL image to a float tensor,
# then Normalize applies (x - 0.5) / 0.5 independently to each of the 3 channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
In [5]:
# CIFAR-10 train and test splits, both stored under ./data and both passed
# through the same `transform`. download=True lets torchvision fetch the
# archive when it is not present yet.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
Files already downloaded and verified
Files already downloaded and verified
In [6]:
# Mini-batches of 8 samples loaded with num_workers=2. Only the training
# split is shuffled; the test split keeps its stored order.
trainloader = DataLoader(
    trainset, batch_size=8, shuffle=True, num_workers=2
)
testloader = DataLoader(
    testset, batch_size=8, shuffle=False, num_workers=2
)
In [7]:
# Display names for the ten CIFAR-10 labels. Assumed to be in label-index
# order (index i names label i); confirm against torchvision's CIFAR10 classes.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
In [8]:
def imshow(img, *, verbose=True):
    """Display a normalized channels-first image tensor with matplotlib.

    The tensor is assumed to come from this notebook's ``transform``
    (``ToTensor`` followed by ``Normalize`` with mean 0.5 and std 0.5 per
    channel), so ``img / 2 + 0.5`` maps it back to the [0, 1] range.

    Args:
        img: image tensor laid out as channels x height x width (assumed,
            not checked).
        verbose: when True (the default, matching the original behavior),
            print the array shape before and after moving the channel axis.
    """
    # Undo Normalize(mean=0.5, std=0.5): x_norm = (x - 0.5) / 0.5, so x = x_norm / 2 + 0.5.
    # detach()/cpu() let tensors that require grad or live on a GPU be
    # converted; plain .numpy() raises for those.
    np_img = (img.detach().cpu() / 2 + 0.5).numpy()
    # ToTensor produces C x H x W, while matplotlib expects H x W x C,
    # so move the channel axis to the end.
    hwc_img = np.transpose(np_img, (1, 2, 0))
    # Float round-off can leave values just outside [0, 1], which makes
    # matplotlib warn and clip; clip explicitly instead.
    plt.imshow(np.clip(hwc_img, 0.0, 1.0))

    if verbose:
        print(np_img.shape)
        print(hwc_img.shape)
In [9]:
# Take exactly one mini-batch from the training loader to inspect shapes.
# next(iter(...)) replaces the enumerate/break loop and the `imgs = 0`
# sentinel: an empty loader now raises StopIteration here instead of leaving
# the integer 0 to fail later inside the conv layer.
imgs, labels = next(iter(trainloader))
print(0, imgs.shape, labels.shape)  # leading 0 = index of the batch taken
0 torch.Size([8, 3, 32, 32]) torch.Size([8])
In [10]:
# Standalone conv layer: 3 input channels -> 5 output channels, 5x5 kernel.
net = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=5)
In [11]:
# Feed the batch tensor straight into the layer. torch.autograd.Variable has
# been deprecated since PyTorch 0.4 (tensors track autograd state themselves),
# and wrapping with it is a no-op that just returns the tensor.
out1 = net(imgs)
print(out1.shape)  # 5x5 kernel, no padding: each spatial side shrinks by 4
torch.Size([8, 5, 28, 28])
In [12]:
# Second standalone conv: consumes net's 5 output channels, emits 10, 5x5 kernel.
net2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=5)
In [13]:
# Chain the second conv onto the first conv's output (5 -> 10 channels).
out2 = net2(out1)
# Tensor.size() returns the same torch.Size object that .shape exposes.
print(out2.size())
torch.Size([8, 10, 24, 24])
In [15]:
class my_network(nn.Module):
    """Two stacked 5x5 convolutions taking 3 channels to 5, then 5 to 10.

    Neither conv is padded, so each one trims 4 pixels from height and width
    (a 32x32 input comes out as 24x24). No activation sits between the layers.
    """

    def __init__(self):
        super().__init__()
        # Keep net_1 registered before net_2 so parameter and state_dict
        # ordering stays the same.
        self.net_1 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=5)
        self.net_2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=5)

    def forward(self, x):
        # Apply the two convolutions back to back.
        return self.net_2(self.net_1(x))
In [16]:
# Take exactly one (freshly shuffled) mini-batch from the training loader.
# next(iter(...)) replaces the enumerate/break loop and the `imgs = 0`
# sentinel: an empty loader now raises StopIteration here instead of leaving
# the integer 0 to fail later inside the model.
imgs, labels = next(iter(trainloader))
print(0, imgs.shape, labels.shape)  # leading 0 = index of the batch taken
0 torch.Size([8, 3, 32, 32]) torch.Size([8])
In [17]:
# Instantiate the two-conv model (net_1: 3 -> 5 channels, net_2: 5 -> 10 channels).
my_net = my_network()
In [18]:
# Pass the batch tensor directly. torch.autograd.Variable is deprecated since
# PyTorch 0.4, and wrapping with it is a no-op that just returns the tensor.
out = my_net(imgs)
print(out.shape)  # two unpadded 5x5 convs: each spatial side shrinks by 8 in total
torch.Size([8, 10, 24, 24])
In [ ]:
 
반응형
Comments