# Warm-up: Minimum Height Trees
**Python**
```python
from collections import defaultdict
def find_min_height_trees(n, edges):
    """Return the node labels that, used as roots, give minimum-height trees.

    The current leaves (nodes with exactly one neighbour) are stripped off
    layer by layer until at most two nodes remain; those survivors are the
    centre(s) of the tree.

    Args:
        n: Number of nodes, labelled ``0 .. n-1``.
        edges: Iterable of ``(u, v)`` pairs describing an undirected tree.

    Returns:
        A list with the one or two centre nodes (order not significant).

    Raises:
        ValueError: if peeling runs out of leaves while more than two nodes
            remain, i.e. ``edges`` does not describe a tree on ``n`` nodes.
    """
    if n == 1:
        return [0]
    graph = defaultdict(set)
    for i, j in edges:
        # Adjacency values are sets, so neighbours are inserted with ``add``.
        graph[i].add(j)
        graph[j].add(i)
    # The initial leaves are the nodes with exactly one edge.
    leaves = [node for node, nbrs in graph.items() if len(nbrs) == 1]
    remaining = n
    # Keep removing leaves until only two or fewer nodes are left.
    while remaining > 2:
        if not leaves:
            # A real tree with more than two nodes always has leaves; without
            # this guard a cyclic/disconnected input would loop forever.
            raise ValueError("edges do not describe a tree on n nodes")
        remaining -= len(leaves)
        new_leaves = []
        for leaf in leaves:
            # A leaf has exactly one neighbour; detach it from both sides.
            nbr = graph[leaf].pop()
            graph[nbr].remove(leaf)
            if len(graph[nbr]) == 1:
                new_leaves.append(nbr)
        leaves = new_leaves
    return leaves
```
**Java**
```java
class Solution {
public List<Integer> findMinHeightTrees(int n, int[][] edges) {
if (n == 0) {
return new ArrayList<>();
} else if (n == 1) {
List<Integer> ret = new ArrayList<>();
ret.add(0);
return ret;
}
List<Integer>[] graph = this.buildGraph(n, edges);
List<Integer> leaves = new ArrayList<>();
for (int i = 0; i < n; i++) {
if (graph[i].size() == 1) {
leaves.add(i);
}
}
int count = n;
// Keep removing leaves until there are only 1 or 2 nodes left
List<Integer> newLeaves;
while (count > 2) {
int size = leaves.size();
newLeaves = new ArrayList<>();
count -= size;
for (int i = 0; i < size; i++) {
int leaf = leaves.get(i);
for (int j = 0; j < graph[leaf].size(); j++) {
int toRemove = graph[leaf].get(j);
graph[toRemove].remove(Integer.valueOf(leaf));
if (graph[toRemove].size() == 1)
newLeaves.add(toRemove);
}
}
leaves = newLeaves;
}
return leaves;
}
private List<Integer>[] buildGraph(int n, int[][] edges) {
List<Integer>[] graph = new ArrayList[n];
for (int i = 0; i < n; i++) {
graph[i] = new ArrayList<>();
}
for (int i = 0; i < edges.length; i++) {
int v1 = edges[i][0];
int v2 = edges[i][1];
graph[v1].add(v2);
graph[v2].add(v1);
}
return graph;
}
}
```
###### tags: `Week 5-Graphs`