Defining a data point

We define what a data point is so we can test our functions as we build them. I've only tested with a few print statements, since most of these are simple functions without many moving parts. Usually, the correct thing to do is to write tests beforehand instead of printing afterwards. Anyway, here's a data point:

# Number is used instead of id and label instead of class, since class is
# a reserved keyword in Python and id would shadow the built-in function.
class DataPoint(object):
    def __init__(self, number, attributes, label):
        self.number = number
        self.attributes = attributes
        self.label = label
    
    def __repr__(self):
        return '{} ({}): {}'.format(self.number, self.label, self.attributes)
    
    def __hash__(self):
        # Hash by number so data points can be stored in sets;
        # equality stays the default identity comparison.
        return self.number
print(DataPoint(1, {'first': 1, 'second': 'A', 'third': True}, 'positive'))
1 (positive): {'third': True, 'second': 'A', 'first': 1}

Info function

The info function tells us the entropy of the set. It is maximal when the set contains several equally likely labels and minimal (zero) for homogeneous sets.
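For reference, here is the formula the function computes (the notation is mine, it isn't spelled out in the code): with $p_c$ the fraction of data points carrying label $c$,

$$\mathrm{Info}(D) = -\sum_{c} p_c \log_2 p_c$$

For two 'A's and one 'B' this gives $-\frac{2}{3}\log_2\frac{2}{3} - \frac{1}{3}\log_2\frac{1}{3} \approx 0.918$, which is the third test below.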

from collections import defaultdict
from math import log

def info(D):
    counts = defaultdict(int)
    for d in D:
        counts[d.label] += 1
        
    N = len(D)
    if N == 0:
        raise Exception('No data points given!')
    probabilities = [float(counts[c])/N for c in counts]
    return -sum(p*log(p, 2) for p in probabilities)
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B')})) # Should be 1
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A')})) # Should be 0
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A'), DataPoint(3, {}, 'B')})) # Should be ~0.918
try:
    info({})
except Exception:
    print('OK')
1.0
-0.0
0.9182958340544896
OK

Gain function of an attribute

We say an attribute's information gain is the entropy of the whole set minus the weighted average entropy (info) of the subsets into which that attribute divides our data. Phew. That means the gain of an attribute is maximal if splitting the data on that attribute yields the most homogeneous subgroups. This can be a downside for attributes with many, many values, since the subgroups will be tiny and consequently very homogeneous, but the tree won't generalize.
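Written out (again, the notation is mine), with $D_v$ the subset of $D$ whose value for attribute $A$ is $v$:

$$\mathrm{Gain}(D, A) = \mathrm{Info}(D) - \sum_{v} \frac{|D_v|}{|D|}\,\mathrm{Info}(D_v)$$

In the first test below, splitting on 'a' yields two singleton subgroups with zero entropy each, so the gain is the full $\mathrm{Info}(D) = 1$ bit.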

def gain(D, attribute):
    # Partition the data points by their value of the given attribute.
    subgroups = defaultdict(set)
    for d in D:
        subgroups[d.attributes[attribute]].add(d)
    
    N = len(D)
    if N == 0:
        raise Exception('No data points given!')
    # Weighted average entropy of the subgroups, subtracted from the
    # entropy of the whole set.
    attribute_info = sum(len(g)*info(g) for _, g in subgroups.items())/N
    return info(D) - attribute_info
print(gain({DataPoint(1, {'a': 1}, 'A'), DataPoint(2, {'a': 0}, 'B')}, 'a')) # Should be 1
print(gain({DataPoint(1, {'a': 1}, 'A')}, 'a')) # Should be 0
try:
    gain({}, 'a')
except Exception:
    print('OK')
1.0
-0.0
OK

Auxiliary functions

Check if a data set is homogeneous

If the data set has only one label, this function returns that label and True. It returns None and False otherwise.

def is_homogenous(D):
    labels = {d.label for d in D}
    
    N = len(labels)
    if N == 0:
        raise Exception("No data points given!")
    elif N == 1:
        return labels.pop(), True
    
    return None, False
print(is_homogenous({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B')})) # None, False
print(is_homogenous({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A')})) # 'A', True
try:
    print(is_homogenous({}))
except Exception:
    print('OK')
(None, False)
('A', True)
OK

Return the most common label in a data set

This (duh) returns the most common label within a data set.

from collections import Counter

def most_common_label(D):
    if len(D) == 0:
        raise Exception('No data points given!')
    counter = Counter([d.label for d in D])
    return counter.most_common(1)[0][0]
print(most_common_label({DataPoint(1, {}, 'A')})) # 'A'
print(most_common_label({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B'), DataPoint(3, {}, 'B')})) # 'B'
try:
    print(most_common_label({}))
except Exception:
    print('OK')
A
B
OK

The decision tree class

I'll leave this as a data class. In truth, though, I could have made the next function (build_tree()) a train() method of the DecisionNode class. is_leaf() is self-explanatory. classify() takes a data point and sends it to the right child for further inspection, based on the node's attribute. Once the data point arrives at a leaf node, it receives that node's label.

class DecisionNode(object):
    def __init__(self):
        self.label = None
        self.attribute = None
        self.children = {}
    
    def is_leaf(self):
        return len(self.children) == 0
    
    def classify(self, data_point):
        if self.is_leaf():
            return self.label
        else:
            value = data_point.attributes[self.attribute]
            return self.children[value].classify(data_point)
    
    def __repr__(self):
        return self.__repr2__()
    
    def __repr2__(self, depth=0):
        s = ''
        if self.is_leaf():
            s += 'label {}'.format(self.label)
        else:
            s += 'decide on {}'.format(self.attribute)
        
        for c in self.children:
            s += '\n{}{}: {}'.format('│ '*(depth+1), c, self.children[c].__repr2__(depth=depth+1))
                
        return s
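As a quick sanity check (these few lines are mine, not part of the original notebook), a small tree can be assembled by hand and asked to classify a point:

# Hand-built one-level tree, just to exercise classify() and __repr__().
leaf_a = DecisionNode()
leaf_a.label = 'A'

leaf_b = DecisionNode()
leaf_b.label = 'B'

node = DecisionNode()
node.attribute = 'a'
node.children = {0: leaf_a, 1: leaf_b}

print(node.classify(DataPoint(99, {'a': 1}, '?')))  # should print B
print(node)  # decide on a, with the two leaves underneath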

Basic Algorithm

The main training algorithm. From Wikipedia:

ID3 (Examples, Target_Attribute, Attributes)
    Create a root node for the tree
    If all examples are positive, Return the single-node tree Root, with label = +.
    If all examples are negative, Return the single-node tree Root, with label = -.
    If number of predicting attributes is empty, then Return the single node tree Root,
    with label = most common value of the target attribute in the examples.
    Otherwise Begin
        A ← The Attribute that best classifies examples.
        Decision Tree attribute for Root = A.
        For each possible value, vi, of A,
            Add a new tree branch below Root, corresponding to the test A = vi.
            Let Examples(vi) be the subset of examples that have the value vi for A
            If Examples(vi) is empty
                Then below this new branch add a leaf node with label = most common target value in the examples
            Else below this new branch add the subtree ID3 (Examples(vi), Target_Attribute, Attributes – {A})
    End
    Return Root

The most important part below is the line

A = max(L, key=attribute_gain)

The key argument is where most versions of this tree-building algorithm differ. Another flexible point is the next loop, where children are added to the root node. Right now, each attribute value corresponds to one child. It doesn't have to be that way: we could add a child for every interval or set of values instead, which would require some small changes to attribute_gain() and the values loop, as sketched below.
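For example, here's a rough sketch (mine, not used anywhere in this notebook) of a binary threshold split for a numeric attribute; threshold_gain() is a hypothetical helper that reuses info():

# Hypothetical sketch: gain of splitting a numeric attribute at a threshold,
# instead of adding one child per distinct value. Not used by build_tree().
def threshold_gain(D, attribute, threshold):
    left = {d for d in D if d.attributes[attribute] <= threshold}
    right = D - left
    if not left or not right:
        return 0.0
    N = len(D)
    split_info = (len(left)*info(left) + len(right)*info(right))/N
    return info(D) - split_info

points = {DataPoint(1, {'c': 0}, 'B'),
          DataPoint(2, {'c': 1}, 'B'),
          DataPoint(3, {'c': 2}, 'A')}
print(max((threshold_gain(points, 'c', t), t) for t in (0, 1)))  # best threshold is 1

Anyway, back to the plain version of ID3: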

def build_tree(D, L):
    root = DecisionNode()
    
    # Base case 1: every data point has the same label.
    label, it_is = is_homogenous(D)
    if it_is:
        root.label = label
        return root
    
    # Base case 2: no attributes left to split on.
    if len(L) == 0:
        root.label = most_common_label(D)
        return root
    
    def attribute_gain(a):
        return gain(D, a)
    
    # Split on the attribute with the highest information gain.
    A = max(L, key=attribute_gain)
    root.attribute = A
    
    values = {d.attributes[A] for d in D}
    for v in values:
        subgroup = {d for d in D if d.attributes[A] == v}
        
        if len(subgroup) == 0:
            child = DecisionNode()
            child.label = most_common_label(D)
        else:
            child = build_tree(subgroup, L - {A})
        
        root.children[v] = child
    
    return root
print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
}, {'a'}), end='\n\n')

print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
    DataPoint(2, {'a': 1}, 'B'),
}, {'a'}), end='\n\n')

print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
    DataPoint(2, {'a': 1}, 'B'),
    DataPoint(3, {'a': 1}, 'B'),
}, {}), end='\n\n')

print(build_tree({
    DataPoint(1, {'a': 1, 'b': 'no'}, 'A'),
    DataPoint(2, {'a': 1, 'b': 'yes'}, 'B'),
}, {'a', 'b'}), end='\n\n')
label A

decide on a
│ 1: label B

label B

decide on b
│ yes: label B
│ no: label A

Classifying

big_tree = build_tree({
    DataPoint(1, {'a': 1, 'b': 'no', 'c': 0}, 'A'),
    DataPoint(2, {'a': 1, 'b': 'yes', 'c': 1}, 'B'),
    DataPoint(3, {'a': 1, 'b': 'yes', 'c': 1}, 'B'),
    DataPoint(4, {'a': 0, 'b': 'no', 'c': 0}, 'B'),
    DataPoint(5, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
    DataPoint(6, {'a': 0, 'b': 'no', 'c': 2}, 'A'),
    DataPoint(7, {'a': 1, 'b': 'yes', 'c': 2}, 'A'),
    DataPoint(8, {'a': 0, 'b': 'no', 'c': 2}, 'A'),
    DataPoint(9, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
    DataPoint(10, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
}, {'a', 'b', 'c'})
print(big_tree, end='\n\n')
decide on c
│ 0: decide on b
│ │ yes: label B
│ │ no: decide on a
│ │ │ 0: label B
│ │ │ 1: label A
│ 1: label B
│ 2: label A

Have fun here…

print(big_tree.classify(DataPoint(11, {'a': 1, 'b': 'yes', 'c': 0}, '?')))
B
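Reading the tree above, here are a couple more points you could try (my examples; the expected labels are read off the printed tree, not from an actual run):

# Predicted from the printed tree: c == 2 goes straight to label A,
# while c == 0, b == 'no', a == 0 ends at label B.
print(big_tree.classify(DataPoint(12, {'a': 0, 'b': 'no', 'c': 2}, '?')))  # A
print(big_tree.classify(DataPoint(13, {'a': 0, 'b': 'no', 'c': 0}, '?')))  # B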