거의 알고리즘 일기장

Data Loader 사용법 본문

PyTorch 사용법

DataLoader 사용법

건우권 2020. 9. 21. 14:57

 

 

 

In [10]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import numpy as np

import torchvision 
import torchvision.transforms as transforms
In [2]:
# Preprocessing pipeline: ToTensor converts each image to a float tensor,
# then Normalize maps every RGB channel through (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
In [4]:
# Build the CIFAR-10 training split first, then the test split. Both use
# ./data as the root (downloading if needed) and share the same `transform`.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root='./data',
        train=split_is_train,
        download=True,
        transform=transform,
    )
    for split_is_train in (True, False)
)
 
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
 
 
 
Files already downloaded and verified
In [5]:
# Mini-batches of 8 images loaded with 2 worker processes; only the
# training loader shuffles its samples.
trainloader, testloader = (
    DataLoader(dataset, batch_size=8, shuffle=shuffle_it, num_workers=2)
    for dataset, shuffle_it in ((trainset, True), (testset, False))
)
In [6]:
# CIFAR-10 class names, indexed by integer label (0 = plane ... 9 = truck).
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())
In [7]:
def imshow(img):
    """Display a normalized image tensor with matplotlib and print its shapes.

    Args:
        img: image tensor in channel-first (C x H x W) layout, e.g. the output
            of ``torchvision.utils.make_grid``. Assumed to have been normalized
            with mean 0.5 / std 0.5 per channel, which the first line undoes.
    """
    img = img / 2 + 0.5  # unnormalize: inverse of (x - 0.5) / 0.5
    # detach()/cpu() let tensors that require grad or live on a GPU be
    # converted as well; both are no-ops for a plain CPU tensor.
    np_img = img.detach().cpu().numpy()
    # ToTensor produces C x H x W, but matplotlib expects H x W x C,
    # so move the channel axis to the end.
    hwc_img = np.transpose(np_img, (1, 2, 0))
    plt.imshow(hwc_img)

    # Debug output: array shape before and after the transpose.
    print(np_img.shape)
    print(hwc_img.shape)
In [8]:
# Pull a single batch from the training loader for inspection.
# Use the built-in next(): DataLoader iterators in recent PyTorch releases
# no longer provide a `.next()` method.
dataiter = iter(trainloader)
images, labels = next(dataiter)
In [12]:
# Print the batch shape, then display the batch as a grid of 4 images per row.
print(images.size())
imshow(torchvision.utils.make_grid(nrow=4, tensor=images))
 
torch.Size([8, 3, 32, 32])
(3, 70, 138)
(70, 138, 3)
 
In [18]:
# Compare grid shapes for nrow=4 vs. make_grid's default nrow, then print the
# class name of every image in the batch, each right-aligned to width 5.
print(images.shape)
print(torchvision.utils.make_grid(images, nrow=4).shape)
print(torchvision.utils.make_grid(images).shape)
# Iterate over the actual labels rather than a hard-coded range(8), so this
# keeps working if the batch size changes or a batch comes back smaller.
print(''.join(f'{classes[label]:>5}' for label in labels.tolist()))
 
torch.Size([8, 3, 32, 32])
torch.Size([3, 70, 138])
torch.Size([3, 36, 274])
 ship frog deer bird  cat deer  dog  cat
In [ ]:
 
반응형
Comments