https://growth-coder.tistory.com/245
이전 포스팅에서 CNN 모델에 대해 배우고 MNIST 손글씨 데이터를 학습시켜 보았다.
이번 포스팅에서는 웹 캠을 통해 읽어온 이미지를 모델로 보내서 손 글씨 숫자를 인식하는 코드를 작성해보려고 한다.
우선 학습한 모델을 저장하고 불러오는 방법부터 알아보자.
학습한 모델을 저장할 때는 모델 자체를 저장하는 방법이 있고 학습한 모델의 파라미터만 저장하는 방법이 있다.
두 방법 모두 알아보자.
개발 환경은 PyCharm을 사용하였다.
모델 저장 및 불러오기
파라미터만 저장
학습을 마치고 나서 모델 안의 state_dict 파일을 저장한다.
# 파라미터만 저장
torch.save(
obj=model.state_dict(),
f='cnn_parameters.pth'
)
이후 모델의 파라미터를 불러와야 하는데 이 때 모델 클래스가 구현되어 있어야 한다.
import torch
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 첫 번째 층
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 두 번째 층
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 전결합 층
self.fc = nn.Linear(7 * 7 * 64, 10, bias=True)
# 전결합 층 가중치 초기화
nn.init.xavier_uniform(self.fc.weight)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
# 평평하게 펼침
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
model = CNN()
model.load_state_dict(
torch.load(f='model_parameters.pth')
)
모델 자체를 저장
모델 자체를 저장하면 된다.
# 모델 자체를 저장
torch.save(
obj=model,
f='cnn_model.pth'
)
모델 자체를 저장했기 때문에 파라미터만 저장했을 때와 달리 이 모델을 불러올 때는 클래스가 없어도 된다.
model = torch.load(f='cnn_model.pth')
손글씨 이미지 숫자 예측하기
이 부분부터는 opencv를 사용한다.
혹시 opencv를 잘 모른다면 아래 포스팅을 보고 오길 바란다.
https://growth-coder.tistory.com/239
이제 본격적으로 이미지 숫자를 예측해보자.
우선 모델을 생성하고 파라미터를 불러오자.
import torch
import torch.nn as nn
import cv2
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 첫 번째 층
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 두 번째 층
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 전결합 층
self.fc = nn.Linear(7 * 7 * 64, 10, bias=True)
# 전결합 층 가중치 초기화
nn.init.xavier_uniform(self.fc.weight)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
# 평평하게 펼침
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
model = CNN()
model.load_state_dict(
torch.load(f='cnn_parameters.pth')
)
그리고 이미지를 불러오자. 나는 아래처럼 이진화가 되어있는 28 * 28 이미지를 사용했다.
혹시 테스트하고 싶다면 아래 이미지 파일을 받아서 사용하면 좋을 것 같다.
우선 0을 테스트 해보자. cv2를 사용하여 이미지를 불러오고 shape와 값 타입을 확인해보자.
img = cv2.imread("1.png")
print(img.shape)
print(type(img[0][0][0]))
이진화 처리가 되어있는 이미지임에도 channel이 3인 모습을 확인할 수 있고 저장되어있는 데이터는 unit8 형식이다.
이제 이 데이터를 학습한 데이터와 동일한 형식으로 바꿔줘야 한다.
학습 시킬 때 X의 shape와 요소 타입을 출력해보면 다음과 같다.
즉 (batch size, channel, height, width) shape를 갖도록 바꿔주고 numpy 배열을 tensor로 바꿔준 후 uint8이 아닌 float32 타입을 갖도록 바꿔줘야 한다.
다음 과정을 통해 입력 가능한 형태로 바꿔주고 shape과 값 타입을 출력해보자.
# 채털 수가 3개라서 gray scale로 변경
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, (28, 28))
# 이미지를 [0, 1] 범위로 스케일링
img = img/255
# 이미지를 텐서로 변경
img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0
# 입력 데이터의 shape과 타입을 출력
print(img_tensor.shape)
print(type(img_tensor[0][0][0][0]))
이제 입력 가능한 형식이 되었다. 생성한 모델 인스턴스에 이 값을 넣어보자.
res = model(img_tensor)
print(res.argmax().item())
우리는 0에서 9까지 분류를 해야하기 때문에 res의 값은 아래와 같이 확률을 요소로 가지는 길이가 10인 리스트가 반환된다.
이제 여기서 가장 값이 큰 인덱스를 구하면 이 값이 곧 예측한 값이 된다.
우리는 1 이미지를 넣었고 인덱스 1의 값이 가장 큰 것을 확인할 수 있다.
argmax 메소드를 통해서 가장 큰 값의 인덱스를 출력하면 그 값이 예측한 값이다.
0과 1을 각각 불러와서 전처리 후 모델에 넣어서 잘 예측하는지 확인하자.
웹 캠을 통해 손글씨 숫자 인식하기
이제 웹 캠을 통해서 손글씨 숫자를 인식해보자.
다음은 웹 캠 설정이다.
# 웹 캠 설정
webcam = cv2.VideoCapture(0)
webcam.set(3, 640) # width 세팅
webcam.set(4, 480) # height 세팅
VideoCapture에 0을 넣으면 기본 캠으로 설정한다.
그리고 width, height를 각각 설정하자.
다음은 캠으로 프레임을 읽어 띄우는 코드이다.
while True:
# success는 성공 여부, img는 이미지
success, img = webcam.read()
print(success)
cv2.imshow("cam", img)
# q를 누르면 무한 반복에서 빠져나옴
if cv2.waitKey(1)&0xFF == ord('q'):
break
read를 사용하면 성공 여부와 이미지를 함께 반환해준다.
무한 반복문으로 한 프레임씩 읽어서 띄워주기 때문에 q를 누르면 반복문에서 빠져나오는 코드도 작성해준다.
벽에 붙여둔 손글씨가 보이는 모습을 확인할 수 있다.
지금부터는 opencv를 사용할 예정이다.
사용법을 모른다면 아래 링크를 참고하는 것을 추천한다.
https://growth-coder.tistory.com/239
https://growth-coder.tistory.com/243
https://growth-coder.tistory.com/244
가장 먼저 도형의 mask를 추출해보자.
위 링크 중 두 번째 링크에서 원하는 색상의 mask를 얻어내는 방법에 대해 배웠다.
그 링크에서는 색상을 감지하기 위해 HSV 색 공간으로 변환 후 범위를 통해 mask를 얻어냈는데 이번에는 threshold 값을 지정하는 방식을 사용하려고 한다.
threshold 방식을 사용하기 위해서는 먼저 이미지를 grayscale로 변경한다.
그리고 cv2.threshole 메소드를 통해 원하는 임계점을 지정하고 이 임계점을 넘어가면 0(검은색)으로 만들고 넘지 못하면 최대값(흰 색)으로 만든다.
retval, thr = cv2.threshold(이미지, 임계값, 임계값을 넘으면 바꿀 값, 옵션)
옵션 종류
1. THRESH_BINARY: 임계 값보다 크면 위에서 지정한 바꿀 값으로 변경 작으면 0으로 변경
2. THRESH_BINARY_INV: 임계값보다 크면 0으로 변경 작으면 위에서 지정한 바꿀 값으로 변경
등등..
두 개의 반환 값 중 두 번째 값이 이진화 처리를 진행한 이미지
우선 어느 정도의 값을 threshold로 정해야 mask를 얻어낼 수 있는지 tracking bar를 생성해서 확인해보자.
# 윈도우 생성
cv2.namedWindow("track bar")
# 윈도우 크기 지정
cv2.resizeWindow("track bar", 640, 480)
# 트랙바가 바뀔 때마다 실행될 함수
def empty(a):
pass
# (트랙바 이름, 윈도우 이름, 첫 값, 최대 값, 함수)
cv2.createTrackbar("threshold", "track bar", 0, 255, empty)
while True:
# success는 성공 여부, img는 이미지
success, img = webcam.read()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
threshold = cv2.getTrackbarPos("threshold", "track bar")
_, mask = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY_INV)
cv2.imshow("cam", mask)
# q를 누르면 무한 반복에서 빠져나옴
if cv2.waitKey(1)&0xFF == ord('q'):
break
나는 110정도가 숫자 mask를 적절하게 얻어낼 수 있었다.
이제 tracking bar 코드는 없애자.
숫자 mask를 얻어냈다면 contour 정보를 얻어내서 bounding rectangle을 원본 이미지에 그려보자.
숫자를 잘 인식하고 있다.
이제 우리는 모델에 넣을 수 있게 이미지를 변형해야 한다.
이미지를 torch.tensor([1, 1, 28, 28])로 변형해보자.
while True:
# success는 성공 여부, img는 이미지
success, img = webcam.read()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, mask = cv2.threshold(img, 110, 255, cv2.THRESH_BINARY_INV)
cv2.imshow("cam mask", mask)
# contour 얻기
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contours:
# bounding rectangle 좌표 얻기
x, y, w, h = cv2.boundingRect(cnt)
cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 255), thickness=2)
cv2.imshow("cam", img)
# q를 누르면 무한 반복에서 빠져나옴
if cv2.waitKey(1)&0xFF == ord('q'):
break
우선 정확히 숫자 이미지만 얻을 수 있게 이미지를 잘라보자.
mask[ y : y + h, x : x + w ]와 같이 이미지를 자르면 숫자 이미지만 얻어낼 수 있다.
참고로 원본 이미지가 아닌 mask에서 잘라야 한다. 결국 들어가는 이미지는 이진화 처리된 이미지이기 때문이다.
이미지만 잘라냈다면 해당 이미지를 정사각형으로 만들어야 한다.
width와 height 중 작은 값에 제로 패딩을 추가하면 된다.
def make_img_square(img):
"""
이진화된 이미지를 받아서 검은색 padding을 넣어서 정사각형으로 만들어서 반환
:param img: 이진화된 정사각형이 아닌 이미지
:return:
"""
h, w = img.shape
# 높이와 너비 중 큰 값을 찾아 정사각형의 한 변의 길이로 설정
max_dim = max(h, w)
# 패딩 값 계산
top_pad = (max_dim - h) // 2
bottom_pad = max_dim - h - top_pad
left_pad = (max_dim - w) // 2
right_pad = max_dim - w - left_pad
# 검은색으로 패딩 추가
square_img = cv2.copyMakeBorder(img, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT,
value=[0, 0, 0])
return square_img
정사각형으로 만들었다면 28*28 사이즈로 변경 후 torch.tensor([1, 1, 28, 28])로 변형 후 모델에 넣어주면 된다.
그리고 한 번 예측 값을 원본 이미지에 띄워보자.
<최종 코드>
import torch
import torch.nn as nn
import cv2
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 첫 번째 층
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 두 번째 층
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 전결합 층
self.fc = nn.Linear(7 * 7 * 64, 10, bias=True)
# 전결합 층 가중치 초기화
nn.init.xavier_uniform(self.fc.weight)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
# 평평하게 펼침
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
model = CNN()
model.load_state_dict(
torch.load(f='cnn_parameters.pth')
)
# 웹 캠 설정
webcam = cv2.VideoCapture(0)
webcam.set(3, 640) # width 세팅
webcam.set(4, 480) # height 세팅
def make_img_square(img):
"""
이진화된 이미지를 받아서 검은색 padding을 넣어서 정사각형으로 만들어서 반환
:param img: 이진화된 정사각형이 아닌 이미지
:return:
"""
h, w = img.shape
# 높이와 너비 중 큰 값을 찾아 정사각형의 한 변의 길이로 설정
max_dim = max(h, w)
# 패딩 값 계산
top_pad = (max_dim - h) // 2
bottom_pad = max_dim - h - top_pad
left_pad = (max_dim - w) // 2
right_pad = max_dim - w - left_pad
# 검은색으로 패딩 추가
square_img = cv2.copyMakeBorder(img, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT,
value=[0, 0, 0])
return square_img
while True:
# success는 성공 여부, img는 이미지
success, img = webcam.read()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, mask = cv2.threshold(img, 110, 255, cv2.THRESH_BINARY_INV)
cv2.imshow("cam mask", mask)
# contour 얻기
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contours:
# bounding rectangle 좌표 얻기
x, y, w, h = cv2.boundingRect(cnt)
cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 255), thickness=2)
# 숫자 이미지만 자르기
cropped_img = mask[y:y+h, x:x+w]
# 정사각형으로 만들기
square_img = make_img_square(cropped_img)
square_img = cv2.resize(square_img, (28, 28,))
cv2.imshow("square", square_img)
# 텐서로 변경
# 이미지를 [0, 1] 범위로 스케일링
square_img = square_img / 255
# 이미지를 텐서로 변경
img_tensor = torch.tensor(square_img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
# 예측
predicted = model(img_tensor)
res = predicted.argmax().item()
# 원본 이미지에 예측 값 그리기
cv2.putText(img, str(res), (x+w//2, y+10), cv2.FONT_ITALIC, 1.5, (0, 255, 255), thickness=1)
cv2.imshow("cam", img)
# q를 누르면 무한 반복에서 빠져나옴
if cv2.waitKey(1)&0xFF == ord('q'):
break
<적절한 임계 값을 찾는 코드>
import cv2
# 웹 캠 설정
webcam = cv2.VideoCapture(0)
webcam.set(3, 640) # width 세팅
webcam.set(4, 480) # height 세팅
# 윈도우 생성
cv2.namedWindow("track bar")
# 윈도우 크기 지정
cv2.resizeWindow("track bar", 640, 480)
# 트랙바가 바뀔 때마다 실행될 함수
def empty(a):
pass
# (트랙바 이름, 윈도우 이름, 첫 값, 최대 값, 함수)
cv2.createTrackbar("threshold", "track bar", 0, 255, empty)
while True:
# success는 성공 여부, img는 이미지
success, img = webcam.read()
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
threshold = cv2.getTrackbarPos("threshold", "track bar")
_, mask = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY_INV)
cv2.imshow("cam mask", mask)
이렇게 웹 캠을 통해 실시간으로 숫자 이미지를 예측하는 과정을 진행하였다.
간단하게 해보았지만 한계점이 존재한다.
먼저 단순하게 gray scale로 변경 후 threshold를 처리했기 때문에 흰 배경이 아닌 어두운 배경에서는 다른 객체를 인식할 가능성이 높다.
또한 조명, 각도, 환경에 따라 threshold 값이 달라질 수도 있다.
본인의 환경에 맞게끔 적절하게 이미지를 처리하는 과정을 추가하는 것이 좋다.
'공부 > AI' 카테고리의 다른 글
[LangChain] LangChain 개념 및 사용법 (0) | 2023.10.18 |
---|---|
[PyTorch] 긍정 리뷰, 부정 리뷰 분류하기 (3) - 모델 변경 (GRU) (0) | 2023.09.19 |
[PyTorch] 긍정 리뷰, 부정 리뷰 분류하기 (2) - 구현 (임베딩, RNN) (0) | 2023.09.18 |
[PyTorch] 긍정 리뷰, 부정 리뷰 분류하기 (1) - 개념 (0) | 2023.09.15 |
[PyTorch] MNIST로 학습한 CNN 모델로 웹 캠 손 글씨 숫자 인식하기 (1) (0) | 2023.09.09 |
댓글