CNN 모델 중 정말 조상님이라 할 수 있는 LeNet의 버전 중 LeNet-5를 직접 코드로 구현해보려고 한다.
LeNet-5 구조
저 숫자가 어떻게 나온거지?
저기 있는 사이즈, 채널에 대해 숫자가 어떻게 계산되어 나왔는지 정리를 해보았다.
각 레이어들 정리
Input size : 32X32 1채널
C1 : 28 * 28 커널 수 6개
S2 : 14*14 커널 수 6개
C3 : 10*10 커널 수 16개
S4 : 5*5 커널 수 16개
C5 : 120
F6 : 84
OUPUT : 10
계산을 위한 공식
P : Padding사이즈
FH : Filter(Kernal)의 Height
FW : Filter(Kernal)의 Width
S : Stride 크기
이제 정말 계산을 해보자
S2는 Sampling이므로 PoolingSize가 2*2 이여야 14*14로 나올 수 있고 커널 수는 유지
C3을 얻기 위해 다시 5*5 필터를 적용하면
(14 + 0 -5 + 1) = 10 * 10
S4는 Sampling이므로 PoolingSize가 2*2 이여야 5*5로 나올 수 있고 커널 수는 유지된다.
C5에서는 FC(Full Connection)이 적용된다.
우선 Convolution작업을 한다.
(5 + 0 - 5 + 1) = 1 * 1 이고 커널 수는 120으로 한다.
이후 FC작업을 통해 일렬로 배치를 하고 84, 10으로 줄여간다.
모델 코드 구현
이걸 구현한 코드를 살펴보자.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1)
self.fc1 = nn.Linear(in_features=120, out_features=84)
self.fc2 = nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.tanh(x) # 활성함수
x = F.max_pool2d(x, 2,2)
x = self.conv2(x)
x = F.tanh(x)
x = F.max_pool2d(x, 2,2)
x = self.conv3(x)
x = F.tanh(x)
x = x.view(-1, 120)
x = self.fc1(x)
x = F.tanh(x)
x = self.fc2(x)
x = F.tanh(x)
return x
net = Net()
print(net)
pytorch에서 모델은 nn.Module을 상속받아야한다. 상속받은 이후 init 함수에서는각 레이어를 정의한다. forward 함수에서는 forward propagation이 진행되는데 init에서 정의한 레이어가 하나씩 진행된다.
Conv2d에는 파라미터가 3개 들어있다. image의 input channel, output channel, kernel size 이다. 원래 뒤에 stride도 들어가는데 default가 1이라 생략되었다.
활성화 함수를 통해 출력 범위를 지정해서 각 확률을 정리해준다.
아직 의문점
논문에서는 활성화함수로 어떤걸 사용했는지 정확히 나와있지는 않다. 그러면 아무거나 사용해도 상관이 없을까?
아시는 분은 댓글로 부탁드립니다..
'Python > 딥러닝 (Deep-Learning)' 카테고리의 다른 글
Product에 Object Detection을 도입하고 싶은데 딥러닝은 하나도 모를때 읽으면 괜찮을만한 글(2) (0) | 2022.05.07 |
---|---|
Product에 Object Detection을 도입하고 싶은데 딥러닝은 하나도 모를때 읽으면 괜찮을만한 글(1) (2) | 2022.02.28 |
우리 EasyOCR로 한번 가자(2) (6) | 2021.04.18 |
우리 EasyOCR로 한번 가자(1) (2) | 2021.04.13 |
Tensorlfow Object Detection API 사용 중 발생한 에러 정리 (0) | 2021.03.25 |