Pytorch/튜토리얼
[PyTorch] 공식 문서 Learn the Basics 요약 - 2. Dataset and DataLoaders, Normalization
대두코기
2023. 1. 16. 18:30
반응형
목차
- 데이터셋 및 데이터 로더
- 데이터셋 로드
- 데이터셋 반복 및 시각화
- DataLoaders를 사용한 학습을 위한 데이터 준비
- DataLoader를 통해 반복
- 정규화
데이터셋 및 데이터 로더 (Datasets and Dataloaders)
- PyTorch에서는 DataLoader와 Dataset 두 가지 데이터 프리미티브를 제공하여 데이터셋 코드를 모델 훈련 코드와 분리하여 읽기 쉽고 모듈화 가능
- Dataset은 샘플과 해당 레이블을 저장하고, DataLoader는 Dataset 주위에 이터러블을 감싸 샘플에 쉽게 접근 가능
- Dataset은 개별 데이터 항목을 검색하도록 설계됨
- DataLoader는 데이터 모음을 처리하도록 설계됨
- PyTorch는 이미지, 텍스트, 오디오 등 다양한 종류의 데이터셋을 제공하여 사용 가능
데이터셋 로드 (Loading a dataset)
- Fashion-MNIST 데이터셋을 TorchVision에서 로딩할 것
- Fashion-MNIST는 60,000개의 훈련 예제와 10,000개의 테스트 예제로 구성된 Zalando 아티클 이미지 데이터셋
- 각 예제는 28x28 크기의 흑백 이미지와 10 클래스 중 하나에 해당하는 레이블을 포함
- 각 이미지는 28픽셀 x 28픽셀로 784픽셀로 구성
- 10 클래스는 이미지의 종류를 알려줌 (예: T-shirt/top, Trouser, Pullover, Dress, Bag, Ankle boot 등)
- 흑백 이미지는 0~255 사이의 값으로 흑과 백의 강도를 측정 (예: 흰색은 0, 검은색은 255)
- root: 학습/테스트 데이터가 저장되는 경로
- train: 훈련/테스트 데이터 지정
- download=True: 루트에서 사용할 수 없는 경우 인터넷에서 데이터 다운로드
- transform/target_transform: 피쳐 및 레이블 변환을 지정
%matplotlib inline
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Using downloaded and verified file: data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Using downloaded and verified file: data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Using downloaded and verified file: data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
데이터셋 반복 및 시각화 (Iterating and Visualizing the Dataset)
Datasets를 리스트처럼 수동으로 인덱싱 가능 : training_data[index]
시각화를 위해 matplotlib 사용
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Dataloaders를 사용한 학습을 위한 데이터 준비 (Preparing your data for training with DataLoaders)
- Dataset은 한번에 하나의 샘플씩 데이터셋의 피쳐와 레이블을 가져옴
- 모델을 훈련할 때는 일반적으로
- "minibatch"로 샘플을 전달
- 에포크마다 데이터를 재셔플링하여 모델 과적합을 줄임
- Python의 멀티프로세싱을 사용하여 데이터 추출을 가속화
- 머신러닝에서는 데이터셋의 특징과 레이블을 지정해야 함
- 피쳐는 입력, 레이블은 출력
- 피쳐를 사용하여 모델을 훈련하고 레이블을 예측하도록 함
- 레이블은 10가지 클래스 종류: T-shirt, Sandal, Dress 등
- 피쳐는 이미지 픽셀에서 패턴
- DataLoader는 이러한 복잡성을 추상화하여 사용하기 쉬운 API로 제공
- DataLoader를 사용할 때는 다음 파라미터를 설정해야 함:
- data: 모델을 훈련하는 데 사용되는 훈련 데이터와 모델을 평가하는 데 사용되는 테스트 데이터
- batch_size: 각 배치에서 처리할 레코드의 수
- shuffle: 인덱스별로 무작위의 샘플 섞는지 여부
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
Dataloader를 통한 반복 (Iterate through the DataLoader)
- Dataloader에 데이터셋을 로딩했고, 필요할 때마다 데이터셋을 반복할 수 있음
- 각 반복은 train_features와 train_labels(각각 batch_size=64의 피쳐와 레이블을 포함하는) 배치를 반환
- shuffle=True를 지정했기 때문에, 모든 배치를 반복한 후 데이터가 섞임 (데이터 로딩 순서를 더 세분화할 수 있음)
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 6
정규화 (Normalization)
- 정규화란 데이터 전처리 기술
- 정규화는 각 피쳐에서 동일한 학습에 대한 기여를 보장하기 위해 데이터를 scale하거나 transform함
- 정규화는 피쳐 간의 distinction을 왜곡하지 않고 데이터의 범위를 변경함
- 정규화는 다음을 방지하는데 도움을 줌
- 예측 정확도 감소
- 모델이 학습하기 어려움
- 피쳐 데이터의 범위의 바람직하지 않은 분포
반응형