Skip to content

Decision tree learning #
Find similar titles

Structured data

Category
Algorithm

Decision tree learning #

결정 트리 학습법(Decision tree learning)이란, 결정 트리를 사용하여 어떤 항목에 대한 관측값과 목표값을 연결시켜주는 예측 모델이다. 이는 통계학과 데이터 마이닝, 기계 학습에서 사용하는 예측 모델링 방법 중 하나이다. 트리 모델 중 목표 변수가 유한한 수의 값을 가지는 것을 분류 트리라 한다. 이 트리 구조에서 잎(leaf node)은 클래스 라벨을 나타내고 가지는 클래스 라벨과 관련있는 특징들의 논리곱을 나타낸다. 반대로, 결정 트리 중 목표 변수가 연속하는 값, 일반적으로 실수를 가지는 것은 회귀 트리라 한다.

의사 결정 분석에서 결정트리는 시각적이고 명시적인 방법으로 의사 결정 과정과 결정된 의사를 보여주는데 사용된다. 데이터 마이닝 분야에서 결정 트리는 결정된 의사보다는 자료 자체를 표현하는데 사용된다. 다만, 데이터 마이닝의 결과로서의 분류 트리는 의사 결정 분석의 입력 값으로 사용될 수 있다.

Decision tree 예제 #

데이터셋 #

목표는 해당 학습 데이터 세트로 특정 직원이 인터뷰를 잘 할 수 있을지 예측하는 것이다. 데이터는 직원의 직급과 사용하는 프로그래밍 언어, 트위터 계정 유무, 박사 학위 유무와, 라벨로서 과거 인터뷰 평가(잘했으면 True, 아니면 False) 내역이 주어져있다.

dataset = [
    ({'level': 'Senior', 'lang': 'Java', 'tweets': 'no', 'phd': 'no'}, False),
    ({'level': 'Senior', 'lang': 'Java', 'tweets': 'no', 'phd': 'yes'}, False),
    ({'level': 'Mid', 'lang': 'Python', 'tweets': 'no', 'phd': 'no'}, True),
    ({'level': 'Junior', 'lang': 'Python', 'tweets': 'no', 'phd': 'no'}, True),
    ({'level': 'Junior', 'lang': 'R', 'tweets': 'yes', 'phd': 'no'}, True),
    ({'level': 'Junior', 'lang': 'R', 'tweets': 'yes', 'phd': 'yes'}, False),
    ({'level': 'Mid', 'lang': 'R', 'tweets': 'yes', 'phd': 'yes'}, True),
    ({'level': 'Senior', 'lang': 'Python', 'tweets': 'no', 'phd': 'no'}, False),
    ({'level': 'Senior', 'lang': 'R', 'tweets': 'yes', 'phd': 'no'}, True),
    ({'level': 'Junior', 'lang': 'Python', 'tweets': 'yes', 'phd': 'no'}, True),
    ({'level': 'Senior', 'lang': 'Python', 'tweets': 'yes', 'phd': 'yes'}, True),
    ({'level': 'Mid', 'lang': 'Python', 'tweets': 'no', 'phd': 'yes'}, True),
    ({'level': 'Mid', 'lang': 'Java', 'tweets': 'yes', 'phd': 'no'}, True),
    ({'level': 'Junior', 'lang': 'Python', 'tweets': 'no', 'phd': 'yes'}, False)
]

셰넌 엔트로피 계산 정의 #

셰넌 엔트로피란, 어떤 상태가 얼마만큼의 정보를 담고 있는가를 측정한 것이다. 이는 정보량이라 표현할 수 있으며, 정보이론의 기초를 따르면 정보를 더 많이 알면 알수록 새롭게 알 수 있는 정보는 적어지므로, 엔트로피가 높을수록 새로운 정보일 가능성이 높다. 이는 다시말하면, 정보량은 상태의 불확실성과 관계가 있다. 이를 수식으로 표현하면 아래와 같다.

H(S) = -( p_1 * log_2(p1) ) ... -( p_n * log_2(p2) )

이를 코드화하면 아래와 같다.

def entropy(class_probabilities):
    return sum(-p * math.log(p, 2) for p in class_probabilities if p)

이 함수를 이용하여, 데이터셋에서 각 데이터가 특정 라벨에 속할 확률을 구할 수 있다.

def class_probabilities(dataset):
    '''학습데이터로 특정 라벨에 속할 확률 구하기: [False에 속할 확률, True에 속할 확률]'''
    labels = [label for _, label in dataset]
    total_count = len(labels) # 데이터셋 총 개수
    return [count / total_count for count in Counter(labels).values()]

이 코드에 데이터셋을 넣어서 결과를 확인하면, [0.35714285714, 0.64285714] 이다. 각각 False, True 라벨에 속할 확률이다.

데이터셋의 라벨에 대한 확률분포를 구했으니, 이를 이용하여 데이터셋의 엔트로피를 구할 수 있다.

def data_entropy(dataset):
    probabilities = class_probabilities(dataset)
    return entropy(probabilities)

이는 전체 데이터셋에 대한 엔트로피이다. 하지만, 결정 트리의 각 레벨을 넘기다보면 데이터셋이 분할되므로, 여러개로 분할된 각각의 데이터셋의 엔트로피를 조합하여 전체 데이터셋의 엔트로피를 구할 수 있어야 한다. 이는 각 엔트로피의 가중합으로 정의할 수 있다.

def partition_entropy(subsets):
    total_count = sum(len(subset) for subset in subsets) # q를 구하기 위한 전체 데이터셋 길이
    return sum(data_entropy(subset) * (len(subset) / total_count)
              for subset in subsets)

Decision tree 생성 알고리즘 #

ID3 알고리즘에 기반하면, 기본은 그리디 알고리즘이며, 이는 각 순간마다 최적의 선택을 한다는 뜻이다. 최적의 선택 기준은 위에서 설명한 엔트로피가 될 것이고, 모든 단계를 지나면 특정값이 나와야 하므로, 엔트로피가 작은쪽이 최적의 선택이라 할 수 있겠다. 결정트리의 기본적인 로직은 각 속성에 대하여 파티션을 나누고 단계를 진행하면 할 수록 엔트로피는 최소값이 되어 이는 결국 결과값이 되는 것이다.

이를 기반하여, 데이터셋을 속성에 따라 파티셔닝해야 하므로, 이를 코드화하면 아래와 같다.

def partition_by(inputs, attribute):
    groups = defaultdict(list)
    for input in inputs:
        key = input[0][attribute]
        groups[key].append(input)
    return groups

각 파티션에 의해 분할된 데이터셋들의 전체 엔트로피가 기준이 된다고 하였으므로, 이를 코드화하면 아래와 같다.

def partition_entropy_by(inputs, attribute):
    partitions = partition_by(inputs, attribute)
    return partition_entropy(partitions.values())

이를 이용하여, 각 파티션에 대한 엔트로피를 구하여 최소화되는 파티션을 구하고, 서브트리를 만들 수 있다. 먼저 전체 데이터셋에서 엔트로피가 최소화되는 속성을 구해야 한다.

for key in ['level', 'lang', 'tweets', 'phd']:
    print(key, partition_entropy_by(dataset, key))

# level 0.693536138896
# lang 0.860131712855
# tweets 0.788450457308
# phd 0.892158928262

level 속성으로 분할하면 엔트로피가 가장 낮아지므로, 결정트리의 루트 노드는 level 속성으로 분할해야 할 것이다. 분할된 데이터셋은 각각 Senior, Mid, Junior 셋이므로 이 세개의 셋에 대하여 각각 서브트리를 구해야한다.

먼저 Mid 셋을 보면, 모든 클래스의 값이 True 이므로, 서브트리를 만들 필요없이 True를 리턴한다.

Senior 셋을 보면, tweets로 분할한 데이터셋들의 엔트로피가 0이므로, tweets로 분할하여 서브트리를 만든다.

Junior 셋을 보면, phd로 분할하면 엔트로피가 0이므로, phd로 분할하여 서브트리를 만든다.

이를 도식화하면, 아래와 같다.

 level? -- Senior --> tweets? -- no --> False
        |                     |- yes --> True
        |- Mid --> True
        |- Junior --> phd? -- no --> True
                           |- yes --> False

이 과정을 자동화하면 다음과 같다.

첫째로, ID3 알고리즘을 구현하면, 아래와 같다.

 def build_tree_id3(inputs, split_candidates=None):
    if split_candidates is None:
        split_candidates = inputs[0][0].keys()

    num_inputs = len(inputs)
    num_trues = len([label for item, label in inputs if label])
    num_falses = num_inputs - num_trues

    if num_trues == 0:
        return False
    if num_falses == 0:
        return True

    if not split_candidates:
        return num_trues >= num_falses

    best_attribute = min(split_candidates, key=partial(partition_entropy_by, inputs))

    partitions = partition_by(inputs, best_attribute)
    new_candidates = [a for a in split_candidates
                     if a != best_attribute]

    subtrees = { attribute_value: build_tree_id3(subset, new_candidates)
        for attribute_value, subset in partitions.items() }

    subtrees[None] = num_trues > num_falses 
    return (best_attribute, subtrees)

이후, 분류기를 구현하면 다음과 같다.

  def classify(tree, input):
    if tree in [True, False]:
        return tree
    attribute, subtree_dict = tree
    subtree_key = input.get(attribute)
    if subtree_key not in subtree_dict:
        subtree_key = None
    subtree = subtree_dict[subtree_key]
    return classify(subtree, input)

결과 #

아래와 같은 새로운 직원 데이터에 대하여 이 직원이 인터뷰를 잘할 수 있을지 예측해보자.

# 새로운 데이터
data = {"level": "Junior", "lang": "Java", "tweets": "yes"}

# 학습데이터를 이용하여 결정트리 생성
tree = build_tree_id3(dataset)
# 실제로 분류해보기
classify(tree, data)
# 결과: True

참고 사이트 #

0.0.1_20140628_0