반응형
목차
더보기
What is a neural network
Components of a neural network
Activation function
Weights
Bias
Build a neural Network
Get a hardware device for training
Define the class
Weight and Bias
Model layers
nn.Flatten
nn.Linear
nn.ReLU
nn.Sequential
nn.Softmax
Model parameters
What is a neural network
- 신경망: 레이어로 연결된 뉴런의 모음
- 뉴런: 문제를 해결하기 위해 함께 작동하는 작은 컴퓨팅 단위
- 레이어 유형: input / hidden / output
- 각 레이어에는 입력 레이어를 제외, 많은 뉴런이 포함되어 있음
- 신경망은 인간 두뇌의 정보 처리 방식을 모방
Activation function
- 신경망 노드 또는 인공 뉴런의 출력에 적용되어 해당 노드의 출력을 결정하는 수학 함수
- 뉴런을 활성화할지 여부를 결정
- 신경망 계산에는 활성화 함수 적용이 포함됨
- 뉴런이 활성화되면 해당 뉴런의 입력이 중요하다는 의미
- 다양한 유형의 활성화 함수 존재
- 원하는 출력에 따라 다르게 선택 가능
- 모델에 비선형성 추가
Weights
- 훈련 과정에서 학습되는 매개변수
- 네트워크의 출력과 예상 출력값에 얼마나 근접하는지 영향을 미침
- 입력은 뉴런의 가중치로 곱해지고 출력은 관찰(observed, 현재 뉴런 또는 계층 내에서 사용되거나 분석됨)되거나 다음 계층으로 전달(passed, 추가 처리를 위해 신경망의 다음 계층으로 전송)됨
- 레이어의 모든 뉴런에 대한 가중치는 텐서로 구성됨
Bias
- 활성 함수를 통과하기 전에 입력의 가중 합에 추가되는 추가 항
- 활성화 함수의 출력과 예상 출력의 차이를 보정
- 바이어스가 낮음은 네트워크의 출력에 대해 확신이 높음
- 바이어스가 높음은 네트워크의 출력에 대해 확신이 낮음
Build a neural Network
- 신경망: 다른 모듈(계층)로 구성된 모듈
- torch.nn 네임스페이스는 신경망을 위한 빌딩 블록을 제공
- PyTorch의 모든 모듈은 nn.Module의 하위 클래스
- 이 중첩 구조를 통해 복잡한 아키텍처를 쉽게 구축하고 관리 가능
- 키워드: 신경망, 레이어, 모듈, torch.nn, nn.Module
%matplotlib inline
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Get a hardware device for training
torch.cuda를 이용해 GPU 사용 가능 여부 확인, 아니라면 CPU 사용
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
Using cuda device
Define the class
- nn.Module을 서브클래싱하여 신경망 정의
- __init__ 메서드에서 신경망 계층 초기화
- forward method에서 입력 데이터에 대한 작업을 구현
- 입력 레이어에는 28x28 또는 784개의 피쳐/픽셀
- 첫 번째 선형 모듈:
입력: 784개의 피쳐
출력:512개의 피쳐 - ReLU 활성화 함수는 1~3 번째 리니어 모듈의 변환에 적용
- 두 번째 선형 모듈:
입력: 첫 번째 히든 레이어의 출력인 512개의 피쳐
출력: 512개의 피쳐 - 세 번째 선형 모듈:
입력: 두 번째 히든 레이어의 출력인 512개의 피쳐
출력: 10개의 클래스
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
(5): ReLU()
)
)
- forward method 및 기타 백그라운드 작업을 실행하기 위해 입력 데이터를 모델에 전달
- model.forward() 직접 호출 금지
- 입력과 함께 모델 호출 시 각 클래스에 대한 raw 예측 값이 있는 10차원 텐서 반환
- nn.Softmax의 인스턴스를 통해 prediction density 얻음
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
Predicted class: tensor([2], device='cuda:0')
Weight and Bias
- nn.Linear 모듈은 각 레이어에 대해 weight, bias 초기화
- 초기화한 값은 Tensors에 저장됨
print(f"First Linear weights size: {model.linear_relu_stack[0].weight.size()} \n")
print(f"First Linear weights: {model.linear_relu_stack[0].weight} \n")
print(f"First Linear weights size: {model.linear_relu_stack[0].bias.size()} \n")
print(f"First Linear weights: {model.linear_relu_stack[0].bias} \n")
First Linear weights size: torch.Size([512, 784])
First Linear weights: Parameter containing:
tensor([[ 0.0345, 0.0125, -0.0064, ..., -0.0289, -0.0272, 0.0049],
[-0.0083, 0.0006, 0.0139, ..., 0.0288, 0.0230, -0.0266],
[ 0.0202, 0.0218, -0.0205, ..., -0.0233, -0.0014, 0.0225],
...,
[-0.0056, 0.0278, 0.0311, ..., -0.0061, 0.0304, 0.0324],
[-0.0209, 0.0126, -0.0110, ..., 0.0329, -0.0086, -0.0268],
[ 0.0089, -0.0237, 0.0225, ..., -0.0266, 0.0019, 0.0306]],
device='cuda:0', requires_grad=True)
First Linear weights size: torch.Size([512])
First Linear weights: Parameter containing:
tensor([-3.3893e-02, 1.1856e-02, -2.9198e-03, 7.7550e-05, -3.6560e-03,
-6.9579e-03, -1.2820e-02, -1.3843e-02, -3.2885e-02, -3.0706e-02,
1.8154e-02, -2.2628e-02, -1.9467e-02, -1.1431e-02, 3.4236e-02,
1.4550e-02, 3.5565e-02, 6.1036e-03, 1.5597e-02, -3.4587e-02,
8.7269e-03, -1.1294e-02, -1.6513e-02, -1.8476e-02, -1.3202e-02,
8.8488e-03, 9.9392e-04, 5.2359e-03, -2.4069e-03, -1.5483e-02,
2.5843e-02, 1.0256e-02, -2.9113e-02, -2.5873e-02, -1.9944e-02,
-2.6365e-02, 2.5999e-02, 2.7609e-02, 1.5273e-02, -3.3479e-02,
-6.3627e-03, 2.9811e-02, -1.2839e-02, -5.0589e-03, -3.1901e-02,
3.3307e-02, -4.5675e-03, -2.6921e-02, 2.1544e-02, 7.8810e-03,
-2.4426e-02, -2.1167e-03, -3.1412e-02, -1.5664e-02, 1.6812e-02,
-1.7104e-02, 3.4536e-02, -6.5952e-03, 2.6700e-02, -1.2066e-02,
6.5335e-03, 2.8978e-02, 1.8800e-02, 1.4221e-02, 3.2925e-02,
-1.9482e-02, -2.4511e-02, 3.0306e-04, -2.9231e-04, -2.4209e-02,
2.5072e-02, -8.5218e-03, -2.6121e-02, 2.6395e-02, 1.9129e-02,
1.4820e-02, -1.2511e-02, 3.5386e-02, 1.4013e-02, -2.6863e-02,
-4.6941e-04, 2.2849e-02, -1.6215e-02, 3.4199e-02, -1.3711e-02,
2.8474e-02, -3.1933e-02, -3.3415e-02, -1.6178e-02, 1.4155e-02,
-6.5667e-03, 2.3611e-03, -3.2452e-02, 8.6236e-03, -1.5981e-03,
1.1232e-03, 1.1975e-02, -2.3774e-03, 1.7887e-02, 3.5201e-02,
-1.1175e-02, -3.5412e-02, 1.2975e-02, -1.2248e-02, 2.7172e-02,
-3.2291e-02, 1.5862e-02, 2.5353e-02, 3.1547e-02, -3.1174e-02,
-2.1079e-02, 1.9473e-02, 1.6254e-03, -2.6573e-02, 1.8785e-02,
2.9551e-02, -2.4120e-02, -1.6095e-02, -1.1159e-02, -3.2816e-02,
1.4767e-02, 1.6708e-02, 2.6993e-02, -5.7211e-03, -1.8452e-02,
3.2536e-02, 2.8532e-02, 2.2516e-04, 2.1189e-02, -3.3233e-02,
3.0423e-02, 3.3523e-02, -1.0432e-02, -3.2602e-02, -2.8546e-02,
-6.4327e-03, 6.9642e-03, -1.4304e-02, 6.8462e-03, 5.0203e-03,
2.1585e-02, 3.1545e-02, -1.0934e-03, -1.9406e-02, -2.9421e-02,
-1.7411e-02, -2.0412e-02, 2.5102e-02, 1.2547e-02, -3.0986e-02,
1.6634e-02, -2.4984e-02, -2.6488e-02, 3.4600e-03, 3.2208e-02,
1.4877e-02, 3.5410e-02, -3.1163e-02, 2.8038e-02, -5.1891e-03,
5.3359e-03, -2.6971e-02, -3.0791e-02, 9.9291e-03, 2.0621e-02,
1.9868e-03, 1.9627e-02, 1.0046e-02, -2.4433e-02, 2.8893e-02,
1.1766e-02, 1.0414e-02, 2.3457e-03, -1.2746e-02, -1.0818e-02,
-3.2967e-02, -1.4925e-02, -3.2252e-03, -2.2941e-03, 9.0597e-04,
6.9115e-03, -2.1735e-03, -2.4123e-02, -3.3446e-02, 1.0484e-03,
3.9039e-03, 2.3986e-02, -3.4971e-02, 1.3492e-02, 3.0857e-02,
-2.8267e-02, -3.3370e-02, -2.1649e-02, -3.0752e-02, -5.1667e-03,
3.0889e-02, -2.6058e-02, -3.4022e-02, 8.5368e-03, -1.3550e-02,
4.6412e-03, 2.8140e-02, 2.9436e-02, 1.6564e-02, 6.6783e-03,
-2.2961e-02, -3.0092e-02, -1.2963e-02, 1.6480e-02, -1.8736e-02,
-2.1679e-04, -1.8737e-02, -2.3382e-02, -1.0879e-02, -2.6364e-02,
1.4948e-02, -2.4392e-02, 9.0217e-03, -3.0474e-02, 2.4682e-02,
-1.5634e-02, 2.7837e-02, -2.2873e-04, 3.5140e-02, 3.0557e-02,
-1.0857e-02, -7.7954e-03, 2.0694e-02, 2.0313e-02, -2.7490e-02,
-2.2558e-02, 2.8936e-02, 3.4033e-02, 2.6560e-02, 3.0019e-02,
2.9273e-02, -2.7359e-02, -1.2905e-02, -2.0313e-02, 9.5828e-03,
8.8425e-03, -1.7195e-02, -2.0704e-02, 1.0450e-03, 2.5950e-02,
1.9036e-02, 1.4443e-02, 2.9559e-02, 2.9179e-03, -1.0265e-02,
2.3461e-02, -1.3738e-02, -3.0526e-02, 6.4547e-03, 3.4369e-03,
-9.6529e-03, 2.7027e-02, -1.5767e-02, -2.1840e-02, 3.5234e-02,
1.3105e-03, -2.4140e-02, -1.6030e-02, 5.1552e-03, -3.3987e-02,
-3.0117e-02, -8.1727e-03, -2.6289e-02, 3.4823e-02, 3.1985e-02,
-9.3816e-03, 2.1257e-02, 1.2116e-02, -1.3521e-02, -4.1174e-03,
9.8654e-03, -1.1240e-02, -2.1743e-02, -2.4009e-02, -2.7295e-02,
-1.6306e-02, 1.3614e-02, 3.1704e-02, -3.0783e-02, -5.3847e-03,
-1.0096e-03, 1.3106e-02, 2.3826e-02, 1.7730e-03, 9.5286e-03,
-3.3278e-02, 1.8588e-02, 1.8140e-02, 2.7753e-02, -3.4654e-02,
3.2467e-02, 5.3144e-03, 2.2868e-02, -3.2193e-03, -8.4300e-03,
-2.9408e-02, -3.1121e-02, -3.2734e-03, 1.4886e-02, 2.5475e-02,
2.4241e-02, -2.7010e-02, -2.7500e-02, -8.5497e-03, -2.0963e-02,
1.0632e-02, 8.7818e-03, 2.6593e-02, 1.5180e-02, 3.4387e-03,
2.1794e-02, 2.5454e-02, 9.5285e-03, -2.9737e-02, -1.0218e-02,
2.4998e-02, 2.8257e-02, -2.1038e-02, 7.7758e-03, 8.9003e-03,
2.5068e-02, -2.8536e-02, 2.4896e-02, -1.5650e-02, 3.2877e-02,
-1.6008e-02, -9.3496e-03, 4.9157e-03, 8.2707e-04, -5.9403e-03,
-1.5385e-03, 5.3282e-03, -9.3790e-03, 3.4493e-02, 1.8209e-02,
-2.6138e-02, 1.3822e-02, 3.4545e-02, -2.5219e-02, -1.9581e-02,
2.1199e-02, -1.4929e-02, 2.1148e-02, -1.5851e-02, 2.7479e-02,
-2.8841e-02, -5.9577e-03, 3.2018e-02, -2.7280e-02, 1.3414e-03,
1.6764e-03, -1.0262e-02, -5.5625e-03, -4.0343e-03, 1.0875e-02,
8.6537e-03, 2.2670e-02, -1.8073e-02, -1.1438e-02, -3.1738e-02,
-2.8944e-02, -3.1937e-02, 2.5979e-02, 2.4261e-02, -2.2968e-02,
6.3852e-03, 1.9372e-02, -1.9294e-02, 1.1858e-02, -9.6578e-03,
1.0866e-02, 2.6905e-03, 2.7356e-02, -2.8320e-02, 6.9093e-03,
9.8558e-03, 2.3251e-02, 7.5669e-03, 4.5801e-03, 1.9185e-02,
1.5209e-02, -5.4668e-03, -2.8662e-02, 8.1289e-03, -3.1957e-02,
-9.9186e-03, -8.0599e-03, 1.0866e-02, -1.7790e-02, 1.0896e-02,
-1.6066e-02, 2.4573e-02, -2.3933e-02, -6.1969e-03, 9.0088e-03,
2.1756e-02, 3.3143e-02, 3.4638e-02, 3.1814e-02, 4.8598e-03,
3.3483e-02, -1.1576e-03, -2.6452e-02, 3.2410e-02, 1.9123e-03,
1.1893e-02, -1.2665e-02, -2.3922e-02, -3.4212e-02, 1.7064e-02,
-5.5048e-03, -2.6277e-02, 2.8375e-02, -2.5197e-02, 3.2328e-02,
3.1944e-02, 1.7533e-02, -5.4455e-03, 1.4251e-02, -7.3559e-03,
-2.6303e-02, -1.5736e-02, 1.5723e-05, 1.7847e-02, -3.5334e-02,
-2.8193e-02, -6.2669e-03, 1.3890e-02, -3.5535e-02, 2.1055e-02,
-2.6242e-02, 2.0644e-02, -2.7883e-02, 2.9360e-02, 7.8302e-03,
-2.6974e-02, 2.5662e-02, -2.3669e-02, 9.9982e-03, -2.5166e-03,
2.6066e-02, 1.8942e-02, 3.2216e-02, 2.5957e-02, 1.6543e-02,
-7.1712e-03, 8.9232e-03, 1.0911e-03, 1.2185e-02, 9.3086e-03,
-1.6120e-02, 2.5654e-02, -2.1371e-02, -2.8406e-02, -2.8229e-02,
1.1333e-02, -3.4800e-02, 2.0971e-02, -5.9193e-03, -1.7254e-02,
-2.7244e-02, -2.6970e-02, -2.7199e-02, 1.0899e-02, -1.6956e-02,
2.3275e-02, 1.3890e-02, -1.4555e-02, -9.0339e-04, -4.6278e-03,
-7.1598e-03, -2.1138e-02, -1.0034e-02, -1.9791e-04, 3.5472e-02,
-5.5492e-04, -2.6190e-03, 1.0888e-02, -2.2592e-02, 8.5627e-03,
-2.3297e-02, 5.3967e-03, 1.1539e-02, 2.3211e-03, -1.1103e-02,
6.7959e-03, 2.0121e-02, -3.5005e-02, -1.4148e-02, 2.8122e-02,
9.9652e-03, 6.7002e-03, 3.1070e-02, 3.9661e-04, -1.6871e-02,
-3.7969e-03, 2.8706e-02, 1.5999e-02, 2.8321e-02, -6.6723e-03,
6.6234e-03, -1.4215e-02, 4.7967e-03, 2.1303e-02, -4.9500e-03,
1.1886e-02, -3.2711e-02], device='cuda:0', requires_grad=True)
Model layers
- FashionMNIST 모델의 레이어 분석 위해 3장의 28*28 이미지 가져옴
input_image = torch.rand(3,28,28)
print(input_image.size())
torch.Size([3, 28, 28])
nn.Flatten
- 입력 데이터를 1차원 텐서로 평면화하는 모듈
- 완전 연결 계층(fully connected layer)으로 전달하기 전에 입력 데이터의 형태를 변경하는 데 자주 사용됨
- nn.Flatten 레이어를 초기화하여 각 2D 28x28 이미지를 784 픽셀 값의 연속적인 배열로 변환
- (미니배치 차원(dim=0)이 유지됨)
- 각 픽셀은 신경망의 입력 레이어로 전달됨
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())
torch.Size([3, 784])
nn.Linear
- 저장된 가중치와 편향을 사용하여 입력에 선형 변환을 적용하는 모듈
- 완전 연결 계층(fully connected layer)로 자주 사용됨
- 입력 레이어의 각 픽셀의 그레이스케일 값은 계산을 위해 숨겨진 레이어의 뉴런에 연결
- 변환에 사용되는 계산: weight*input + bias
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())
torch.Size([3, 20])
nn.ReLU
- ReLU(Rectified Linear Unit, 정류된 선형 단위) 활성화 함수를 요소별로 입력 데이터에 적용하는 모듈
- 0 if x < 0 else x
- 활성화 함수로 자주 사용됨
- Non-linear activation: 모델의 입력과 출력 사이에 복잡한 매핑을 만드는 것
- 선형 변환 후 적용되어 신경망의 학습에 도움
- ReLU 이외에도 많음
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
Before ReLU: tensor([[-1.5071e-01, -4.2613e-01, 1.3321e-01, -3.1527e-01, 3.5439e-01,
-3.5691e-01, -4.6418e-01, 1.0611e+00, 2.8271e-01, 5.6982e-02,
5.6526e-01, -2.3586e-01, 7.0047e-04, -2.7395e-01, -2.7888e-01,
4.3065e-02, -4.6131e-01, -1.5210e-01, -2.2935e-01, -7.2748e-01],
[-1.4575e-01, -3.4718e-01, 8.3041e-02, -3.9843e-01, 9.0595e-01,
-1.2193e-01, -3.6724e-01, 7.8541e-01, 2.0537e-02, 1.4214e-01,
1.6115e-01, 4.9334e-02, -1.4593e-01, -3.6390e-01, -4.5012e-01,
8.8599e-02, -3.6883e-01, -3.0174e-01, -4.6783e-01, -5.8375e-01],
[-1.1554e-01, -8.4348e-02, 1.8244e-01, -2.1920e-01, 4.7204e-01,
-3.7186e-01, -3.7802e-01, 8.9063e-01, 2.6306e-01, 3.1152e-01,
6.0135e-01, -9.6417e-02, -2.8857e-01, -2.7803e-01, -4.1772e-01,
1.3958e-01, -1.2683e-01, -2.6364e-01, -4.6738e-01, -1.1816e+00]],
grad_fn=<AddmmBackward0>)
After ReLU: tensor([[0.0000e+00, 0.0000e+00, 1.3321e-01, 0.0000e+00, 3.5439e-01, 0.0000e+00,
0.0000e+00, 1.0611e+00, 2.8271e-01, 5.6982e-02, 5.6526e-01, 0.0000e+00,
7.0047e-04, 0.0000e+00, 0.0000e+00, 4.3065e-02, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 8.3041e-02, 0.0000e+00, 9.0595e-01, 0.0000e+00,
0.0000e+00, 7.8541e-01, 2.0537e-02, 1.4214e-01, 1.6115e-01, 4.9334e-02,
0.0000e+00, 0.0000e+00, 0.0000e+00, 8.8599e-02, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 1.8244e-01, 0.0000e+00, 4.7204e-01, 0.0000e+00,
0.0000e+00, 8.9063e-01, 2.6306e-01, 3.1152e-01, 6.0135e-01, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3958e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00]], grad_fn=<ReluBackward0>)
nn.Sequential
- 여러 신경망 계층의 순차적 실행을 허용하는 모듈의 컨테이너
- 단일 클래스에 여러 nn.Module을 쌓아 단일 통합 모듈을 만들 수 있음
- 데이터는 정의된 모듈의 순서대로 전달됨
- seq_modules 과 같은 빠른 네트워크 구성 가능
seq_modules = nn.Sequential(
flatten,
layer1,
nn.ReLU(),
nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)
print(logits)
tensor([[ 0.0136, 0.1250, -0.2537, 0.3161, -0.0241, -0.0789, -0.1728, 0.1266,
-0.0040, -0.2851],
[ 0.1931, 0.0656, -0.2310, 0.3074, -0.1630, -0.0271, -0.0584, 0.2903,
0.0406, -0.2842],
[ 0.0712, 0.1024, -0.2851, 0.3011, -0.1147, -0.0587, -0.1740, 0.1596,
0.0162, -0.2641]], grad_fn=<AddmmBackward0>)
nn.Softmax
- 입력 데이터에 요소별로 softmax 함수를 적용하는 모듈
- 출력 점수를 확률로 변환하기 위해 다중 클래스 분류를 위한 신경망의 마지막 계층으로 자주 사용됨
- 마지막 출력 계층에서만 사용됨
- 결과는 각 클래스에 대한 모델의 예측 밀도를 나타내는 값이며 [0, 1]
- dim 파라미터는 합이 1이 되어야 하는 차원을 말함
- 확률이 가장 높은 노드는 원하는 출력에 대한 예측값
- 신경망의 마지막 linear layer는 로짓 반환(로짓으로 해석, softmax가 확률로 해석되기 때문)
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
print(pred_probab)
tensor([[0.1022, 0.1142, 0.0782, 0.1383, 0.0984, 0.0932, 0.0848, 0.1144, 0.1004,
0.0758],
[0.1174, 0.1034, 0.0768, 0.1316, 0.0822, 0.0942, 0.0913, 0.1294, 0.1008,
0.0728],
[0.1083, 0.1117, 0.0758, 0.1363, 0.0899, 0.0951, 0.0847, 0.1183, 0.1025,
0.0774]], grad_fn=<SoftmaxBackward0>)
Model parameters
- 신경망의 많은 레이어는 parameterized -> 훈련 중 최적화할 수 있는 weight 및 bias 존재
- nn.Module을 서브클래싱(구현되어 있는 것을 상속)하면 모델 객체 내부에 정의된 모든 필드 자동으로 추적
- 모델의 parameters() / named_parameters() 메서드 사용하여 모든 매개변수에 접근 가능
print("Model structure: ", model, "\n\n")
for name, param in model.named_parameters():
print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")
Model structure: NeuralNetwork(
(flatten): Flatten()
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
(5): ReLU()
)
)
Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784]) | Values : tensor([[-0.0320, 0.0326, -0.0032, ..., -0.0236, -0.0025, -0.0175],
[ 0.0180, 0.0271, -0.0314, ..., -0.0094, -0.0170, -0.0257]],
device='cuda:0', grad_fn=<SliceBackward>)
Layer: linear_relu_stack.0.bias | Size: torch.Size([512]) | Values : tensor([-0.0134, 0.0036], device='cuda:0', grad_fn=<SliceBackward>)
Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512]) | Values : tensor([[-0.0262, 0.0072, -0.0348, ..., -0.0374, 0.0345, 0.0374],
[ 0.0439, -0.0101, 0.0218, ..., -0.0419, 0.0212, -0.0081]],
device='cuda:0', grad_fn=<SliceBackward>)
Layer: linear_relu_stack.2.bias | Size: torch.Size([512]) | Values : tensor([ 0.0131, -0.0289], device='cuda:0', grad_fn=<SliceBackward>)
Layer: linear_relu_stack.4.weight | Size: torch.Size([10, 512]) | Values : tensor([[ 0.0376, -0.0359, -0.0329, ..., -0.0057, 0.0040, 0.0307],
[-0.0196, -0.0440, 0.0250, ..., 0.0335, 0.0024, -0.0207]],
device='cuda:0', grad_fn=<SliceBackward>)
Layer: linear_relu_stack.4.bias | Size: torch.Size([10]) | Values : tensor([-0.0287, 0.0321], device='cuda:0', grad_fn=<SliceBackward>)
반응형