Skip to content

k-NN 알고리즘 #
Find similar titles

Structured data

Category
Algorithm

k-최근접 이웃 알고리즘 (k-NN) #

k-최근접 이웃 알고리즘(또는 줄여서 k-NN)은 분류나 회귀에 사용되는 알고리즘이다.

두 경우 모두 입력이 특징 공간 내 k개의 가장 가까운 훈련 데이터로 구성되어 있다. 출력은 k-NN이 분류로 사용되었는지 또는 회귀로 사용되었는지에 따라 다르다.

Image

이미지 출처 : https://ko.wikipedia.org/wiki/K-%EC%B5%9C%EA%B7%BC%EC%A0%91_%EC%9D%B4%EC%9B%83_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

훈련 데이터는 다차원 특징 공간에서의 벡터이다. 알고리즘의 훈련 단계는 오직 훈련 표본의 특징 벡터와 분류명을 저장하는 것이다.

분류 단계에서 k는 사용자 정의 상수이고 분류명이 붙지 않은 벡터 즉, 타겟 데이터는 k개의 훈련 표본 사이에서 가장 빈번한 분류명을 할당함으로써 분류된다.

연속 변수에서 가장 흔하게 사용되는 거리 척도는 유클리드 거리이다. 이산 변수의 경우 중첩 거리와 같은 다른 척도가 사용될 수 있다.

k-NN 예제 #

목표는 특정 식료품이 어떤 카테고리에 속하는지 구하는 것이고, 분류 알고리즘은 k-NN을 사용할 것이다. 분류를 위한 특징 공간은 아삭거림과 단맛이고, 카테고리는 과일, 단백질, 채소이다.

데이터셋 #

아삭거림과 단맛을 벡터화하여 계산할 수 있다고 가정하여 데이터셋을 만들었다. 아삭거림 점수가 높으면 높을수록, 아삭거림이 심하다는 것이고, 마찬가지로 단맛의 점수가 높으면 높을수록, 단맛이 심하다는 뜻이다.

dataset = [
    # ([아삭거림, 단맛], 라벨)
    ([8, 5], "과일"), # 포도
    ([2, 3], "단백질"), # 생선
    ([7, 10], "채소"), # 당근
    ([7, 3], "과일"), # 오렌지
    ([3, 8], "채소"), # 셀러리
    ([1, 1], "단백질") # 치즈
]

거리 계산 정의 #

상술했듯이, k-NN은 데이터들 사이의 거리를 계산하여 분류하는 알고리즘이기 때문에, 거리 계산을 정의해야 한다. 유클리드 거리 계산법을 사용할 것이다.

def calculate_distance(v1, v2):
    cal_diff = np.array(v1) - np.array(v2)
    sq_diff = cal_diff ** 2
    row_sum = sq_diff.sum()
    distance = np.sqrt(row_sum)
    return distance

k-NN 알고리즘 정의 #

k-NN 알고리즘의 구현은 비교적 간단하다.

  • 타겟이 되는 데이터와 모든 데이터셋에 대한 거리를 구한 다음
  • 거리들을 정렬하고
  • 가장 거리가 짧은 k개의 상위 데이터를 구하여
  • 데이터들 안의 라벨들을 카운팅하면 된다.

이를 코드로 나타내면 아래와 같다.

def calculate_knn(k, dataset, target):

    distance = []

    for data, _ in dataset:
        distance.append(
           calculate_distance(data, target)) # 모든 거리를 계산

    distance_sorted_index = np.array(distance).argsort() # 계산한 거리들을 정렬

    result = {}

    for i in range(k): # k개의 상위 거리에 대하여
        _, label = dataset[distance_sorted_index[i]] # 해당 거리에 해당하는 데이터의 라벨을 가져와서
        if label in result:
            result[label] += 1 # 라벨 카운팅
        else:
            result[label] = 1

    return result

결과 #

아삭거림과 단맛이 각각 7, 8 인 미지의 타겟 데이터에 대하여, 이 데이터를 현재 데이터셋과 비교하여 어떤 카테고리에 속하는지 판단할 수 있다. 아래 결과는, 타겟 데이터의 k 범위 내에 카테고리가 과일인 데이터가 1개 있고, 채소인 데이터가 1개 있다는 뜻이다.

calculate_knn(2, dataset, [7, 8]) # {'과일': 1, '채소': 1}

타겟 데이터의 범위인 k 값을 줄여 좀 더 엄밀하게 분류할 수도 있다.

calculate_knn(1, dataset, [7, 8]) # {'채소': 1}

참고 사이트 #

0.0.1_20210630_7_v33