## Defining a data point
We define what a data point is for the purpose of testing our functions as we build them. I've only tested with a few `print` statements, since most of these are simple functions with few moving parts. Usually, the correct thing to do is to *write tests first* rather than *print afterwards*. Anyway, here's a data point:
```python
# number is used instead of id, and label instead of class:
# class is a reserved keyword and id is a built-in in Python.
class DataPoint(object):
    def __init__(self, number, attributes, label):
        self.number = number
        self.attributes = attributes
        self.label = label

    def __repr__(self):
        return '{} ({}): {}'.format(self.number, self.label, self.attributes)

    def __hash__(self):
        # Equality stays identity-based (no __eq__), so distinct objects
        # with the same number still count as different points in a set.
        return self.number
```
```python
print(DataPoint(1, {'first': 1, 'second': 'A', 'third': True}, 'positive'))
```
1 (positive): {'third': True, 'second': 'A', 'first': 1}
## Info function
The info function tells us the entropy of the set. It is highest when the labels are equiprobable and zero for homogeneous sets.
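Concretely, `info()` below computes the Shannon entropy of the label distribution,

$$\mathrm{Info}(D) = -\sum_{c} p_c \log_2 p_c,$$

where $p_c$ is the fraction of points in $D$ carrying label $c$.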
```python
from collections import defaultdict
from math import log

def info(D):
    # Count how many points carry each label.
    counts = defaultdict(int)
    for d in D:
        counts[d.label] += 1
    N = len(D)
    if N == 0:
        raise Exception('No data points given!')
    probabilities = [float(counts[c])/N for c in counts]
    return -sum(p*log(p, 2) for p in probabilities)
```
```python
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B')})) # Should be 1
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A')})) # Should be 0
print(info({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A'), DataPoint(3, {}, 'B')})) # Should be ~0.918
try:
    info({})
except Exception:
    print('OK')
```
1.0
-0.0
0.9182958340544896
OK
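As a quick hand check of the third value: $-\frac{2}{3}\log_2\frac{2}{3} - \frac{1}{3}\log_2\frac{1}{3} \approx 0.918$.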
## Gain function of an attribute
The information gain of an attribute is the entropy of the whole set minus the weighted average entropy (info) of the subsets into which that attribute divides our data. *Phew*. That means the gain of an attribute is highest when splitting the data on that attribute yields the most homogeneous subgroups. This can be a downside for attributes with many, many values: the subgroups end up tiny and therefore very homogeneous, so the gain looks great, but the tree won't generalize.
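In symbols, writing $D_v$ for the subset of $D$ whose value of attribute $A$ is $v$:

$$\mathrm{Gain}(D, A) = \mathrm{Info}(D) - \sum_{v} \frac{|D_v|}{|D|}\,\mathrm{Info}(D_v)$$

This is exactly what `gain()` below computes, with the second term stored in `attribute_info`.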
```python
def gain(D, attribute):
    # Split D into subgroups, one per value of the attribute.
    subgroups = defaultdict(set)
    for d in D:
        subgroups[d.attributes[attribute]].add(d)
    N = len(D)
    if N == 0:
        raise Exception('No data points given!')
    # Weighted average entropy of the subgroups.
    attribute_info = sum(len(g)*info(g) for _, g in subgroups.items())/N
    return info(D) - attribute_info
```
```python
print(gain({DataPoint(1, {'a': 1}, 'A'), DataPoint(2, {'a': 0}, 'B')}, 'a')) # Should be 1
print(gain({DataPoint(1, {'a': 1}, 'A')}, 'a')) # Should be 0
try:
    gain({}, 'a')
except Exception:
    print('OK')
```
1.0
-0.0
OK
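In the first test, splitting on `a` yields two pure singleton subgroups, so the weighted subgroup entropy is 0 and the gain is the full starting entropy of 1 bit.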
## Auxiliary functions
### Check if a data set is homogeneous
If the data set has only one label, this function returns that label and `True`. It returns `None` and `False` otherwise.
```python
def is_homogenous(D):
    labels = {d.label for d in D}
    N = len(labels)
    if N == 0:
        raise Exception("No data points given!")
    elif N == 1:
        return labels.pop(), True
    return None, False
```
```python
print(is_homogenous({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B')})) # None, False
print(is_homogenous({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'A')})) # 'A', True
try:
    print(is_homogenous({}))
except Exception:
    print('OK')
```
(None, False)
('A', True)
OK
### Return the most common label in a data set
This (duh) returns the most common label within a data set.
```python
from collections import Counter

def most_common_label(D):
    if len(D) == 0:
        raise Exception('No data points given!')
    counter = Counter([d.label for d in D])
    return counter.most_common(1)[0][0]
```
```python
print(most_common_label({DataPoint(1, {}, 'A')})) # 'A'
print(most_common_label({DataPoint(1, {}, 'A'), DataPoint(2, {}, 'B'), DataPoint(3, {}, 'B')})) # 'B'
try:
    print(most_common_label({}))
except Exception:
    print('OK')
```
A
B
OK
## The decision tree class
I'll leave this as a plain data class, although, in truth, the next function (`build_tree()`) could just as well be a `train()` method of `DecisionNode`. `is_leaf()` is self-explanatory. `classify()` takes a data point and sends it to the right child for further inspection based on the node's attribute; once the data point reaches a leaf node, it receives that node's label (there's a small hand-built example right after the class).
```python
class DecisionNode(object):
    def __init__(self):
        self.label = None
        self.attribute = None
        self.children = {}

    def is_leaf(self):
        return len(self.children) == 0

    def classify(self, data_point):
        if self.is_leaf():
            return self.label
        else:
            value = data_point.attributes[self.attribute]
            return self.children[value].classify(data_point)

    def __repr__(self):
        return self.__repr2__()

    def __repr2__(self, depth=0):
        # Pretty-print the tree, indenting children by their depth.
        s = ''
        if self.is_leaf():
            s += 'label {}'.format(self.label)
        else:
            s += 'decide on {}'.format(self.attribute)
            for c in self.children:
                s += '\n{}{}: {}'.format('│ '*(depth+1), c, self.children[c].__repr2__(depth=depth+1))
        return s
```
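To see `classify()` and the pretty-printer in action before any training code exists, here's a tiny tree wired up by hand (purely illustrative; `build_tree()` below constructs these automatically):
```python
root = DecisionNode()
root.attribute = 'a'
yes_leaf, no_leaf = DecisionNode(), DecisionNode()
yes_leaf.label, no_leaf.label = 'A', 'B'
root.children = {1: yes_leaf, 0: no_leaf}
print(root)  # decide on a, then 1: label A and 0: label B
print(root.classify(DataPoint(99, {'a': 1}, '?')))  # A
```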
## Basic Algorithm
The main training algorithm. From Wikipedia:
```
ID3 (Examples, Target_Attribute, Attributes)
Create a root node for the tree
If all examples are positive, Return the single-node tree Root, with label = +.
If all examples are negative, Return the single-node tree Root, with label = -.
If number of predicting attributes is empty, then Return the single node tree Root,
with label = most common value of the target attribute in the examples.
Otherwise Begin
A ← The Attribute that best classifies examples.
Decision Tree attribute for Root = A.
For each possible value, vi, of A,
Add a new tree branch below Root, corresponding to the test A = vi.
Let Examples(vi) be the subset of examples that have the value vi for A
If Examples(vi) is empty
Then below this new branch add a leaf node with label = most common target value in the examples
Else below this new branch add the subtree ID3 (Examples(vi), Target_Attribute, Attributes – {A})
End
Return Root
```
The most important part below is the line
```python
A = max(L, key=attribute_gain)
```
The `key` argument is where most of the different versions of this tree-building algorithm differ; a sketch of one common alternative follows below. Another flexible point is the loop that adds children to the root node. Right now each attribute value corresponds to one child, but it doesn't have to be that way: we could add a child for every interval or set of values instead, which would require some small changes to `attribute_gain()` and the `values` loop.
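For instance, C4.5 scores attributes by the *gain ratio*: the gain divided by the entropy of the split itself, which penalizes attributes that shatter the data into many tiny subgroups. Here's a minimal sketch reusing `gain`, `defaultdict`, and `log` from above (the helper names `split_info` and `gain_ratio` are just illustrative, not part of the code in this post):
```python
def split_info(D, attribute):
    # Entropy of the partition induced by the attribute's values.
    counts = defaultdict(int)
    for d in D:
        counts[d.attributes[attribute]] += 1
    N = len(D)
    return -sum((float(n)/N) * log(float(n)/N, 2) for n in counts.values())

def gain_ratio(D, attribute):
    s = split_info(D, attribute)
    # An attribute with a single value gives no usable split.
    return gain(D, attribute) / s if s > 0 else 0.0

# Inside build_tree(), the selection line would then become e.g.
# A = max(L, key=lambda a: gain_ratio(D, a))
```
With that out of the way, here's the full `build_tree()`: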
```python
def build_tree(D, L):
    root = DecisionNode()
    # If every point has the same label, we're done: return a leaf.
    label, it_is = is_homogenous(D)
    if it_is:
        root.label = label
        return root
    # No attributes left to split on: fall back to the majority label.
    if len(L) == 0:
        root.label = most_common_label(D)
        return root
    def attribute_gain(a):
        return gain(D, a)
    # Pick the attribute with the highest information gain.
    A = max(L, key=attribute_gain)
    root.attribute = A
    values = {d.attributes[A] for d in D}
    for v in values:
        subgroup = {d for d in D if d.attributes[A] == v}
        if len(subgroup) == 0:
            child = DecisionNode()
            child.label = most_common_label(D)
        else:
            child = build_tree(subgroup, L - {A})
        root.children[v] = child
    return root
```
```python
print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
}, {'a'}), end='\n\n')
print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
    DataPoint(2, {'a': 1}, 'B'),
}, {'a'}), end='\n\n')
print(build_tree({
    DataPoint(1, {'a': 1}, 'A'),
    DataPoint(2, {'a': 1}, 'B'),
    DataPoint(3, {'a': 1}, 'B'),
}, set()), end='\n\n')
print(build_tree({
    DataPoint(1, {'a': 1, 'b': 'no'}, 'A'),
    DataPoint(2, {'a': 1, 'b': 'yes'}, 'B'),
}, {'a', 'b'}), end='\n\n')
```
label A
decide on a
│ 1: label B
label B
decide on b
│ yes: label B
│ no: label A
## Classifying
```python
big_tree = build_tree({
    DataPoint(1, {'a': 1, 'b': 'no', 'c': 0}, 'A'),
    DataPoint(2, {'a': 1, 'b': 'yes', 'c': 1}, 'B'),
    DataPoint(3, {'a': 1, 'b': 'yes', 'c': 1}, 'B'),
    DataPoint(4, {'a': 0, 'b': 'no', 'c': 0}, 'B'),
    DataPoint(5, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
    DataPoint(6, {'a': 0, 'b': 'no', 'c': 2}, 'A'),
    DataPoint(7, {'a': 1, 'b': 'yes', 'c': 2}, 'A'),
    DataPoint(8, {'a': 0, 'b': 'no', 'c': 2}, 'A'),
    DataPoint(9, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
    DataPoint(10, {'a': 1, 'b': 'yes', 'c': 0}, 'B'),
}, {'a', 'b', 'c'})
print(big_tree, end='\n\n')
```
decide on c
│ 0: decide on b
│ │ yes: label B
│ │ no: decide on a
│ │ │ 0: label B
│ │ │ 1: label A
│ 1: label B
│ 2: label A
Have fun here...
```python
print(big_tree.classify(DataPoint(11, {'a': 1, 'b': 'yes', 'c': 0}, '?')))
```
B
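Reading the printed tree above, a point with `c = 0`, `b = 'no'` and `a = 1` should land on the `label A` leaf:
```python
print(big_tree.classify(DataPoint(12, {'a': 1, 'b': 'no', 'c': 0}, '?')))  # A
```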