코딩뚠뚠

[머신러닝 공부] KNN의 이해 본문

공부/ML&DL

[머신러닝 공부] KNN의 이해

로디네로 2021. 4. 4. 19:53
반응형

2019년 공모전당시 KNN을 이용해 OCR기능을 구현한 적이 있다.

 

당시엔 영상, 머신러닝 분야에 모두 익숙하지 않아 공부하지 않고 코드를 가져와 사용하기만 했기에 이번 기회에 정리해보고자 한다.

 


 

KNN은 지도학습 (Supervised Learning) 을 통한 가장 단순한 분류의 알고리즘 중 하나 이다.

 

아이디어는 새로운 데이터를 공간 상의 가장 가까운 것들과 묶는다는 것이다.

 

빨간색 input data와 가장 가까운 이웃을 찾아갈 때 

 

이웃을 최대 세 명까지만 본다고 하자 (k=3) 그렇다면 그 중에는 보라색 Class B가 더 많기 때문에 빨간색은 Class B로 분류할 수 있다.

 

이웃을 여섯명 까지 본다고 하면 (k=6) 그 중에는 노란색인 Class A가 더 많기 때문에 빨간색은 Class A로 분류된다.

 

여기서 가까운 것에 가중치를 준 것이 Modified KNN 이다.

 

 


 

간단한 예제

 

1. 난수를 사용하여 (0~100) Class 1 과 Class 2 를 x-y 좌표계에 생성시킨다.

 

2. 난수로 새로운 점을 생성하고 이 점이 어떤 class를 가지는지 예측해본다.

 

3. 만약 한개의 새로운 점이 아닌 10개의 요소를 생성하고 예측하기 위해서는 3)과 같이 하면 된다. (배열로전달)

1)번과정

import cv2
import numpy as np
import matplotlib.pyplot as plt
trainData = np.random.randint(0,100,(25,2)).astype(np.float32)
responses = np.random.randint(0,2,(25,1)).astype(np.float32)
red = trainData[responses.ravel()==0]
plt.scatter(red[:,0],red[:,1],80,'r','^')
blue = trainData[responses.ravel()==1]
plt.scatter(blue[:,0],blue[:,1],80,'b','s')
plt.show()

2)번과정

newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')
knn = cv2.ml.KNearest_create()
knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)
ret, results, neighbours ,dist = knn.findNearest(newcomer, 3)
print("result: ", results)
print("neighbours: ", neighbours)
print("distance: ", dist)

3)번과정

newcomers = np.random.randint(0,100,(10,2)).astype(np.float32)
plt.scatter(newcomers[:,0],newcomers[:,1],80,'g','o')
knn = cv2.ml.KNearest_create()
knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)
ret, results, neighbours ,dist = knn.findNearest(newcomers, 3)
print("result: ", results)
print("neighbours: ", neighbours)
print("distance: ", dist)

 

 


 

마치며

 

사실 혼동하던 용어가 있었다.

 

kNN 을 이용해 학습 시켰다고는 할 수가 없다.

 

학습을 해서 모델이 발전하는 것이 아닌 그저 라벨을 부여해 줄 뿐이다.

 

kNN을 쓰면서 train 함수를 사용하게 되는데 학습시킨다고 혼동하지 말자. 사실은 학습이 아니니..

반응형