# Warm-up: Minimum Height Trees

**Python**

```python
from collections import defaultdict


def find_min_height_trees(n, edges):
    """Return the roots of all minimum height trees.

    Repeatedly strips the current layer of leaves from the tree; the one or
    two nodes that remain are the roots giving minimum height.

    Args:
        n: number of nodes, labelled 0..n-1.
        edges: list of [u, v] pairs forming an undirected tree.

    Returns:
        A list with the one or two root labels (empty when n == 0).
    """
    if n == 1:
        return [0]
    graph = defaultdict(set)
    for i, j in edges:
        # The adjacency values are sets, so use add() -- sets have no append().
        graph[i].add(j)
        graph[j].add(i)
    # Fetch the leaves; the leaves are nodes with only one edge
    leaves = [key for key, val in graph.items() if len(val) == 1]
    # Keep removing leaves until there are only two or fewer nodes left
    while n > 2:
        n -= len(leaves)
        new_leaves = []
        for i in leaves:
            # A leaf has exactly one neighbour; detach it from both sides.
            j = graph[i].pop()
            graph[j].remove(i)
            if len(graph[j]) == 1:
                new_leaves.append(j)
        leaves = new_leaves
    return leaves
```

**Java**

```java
import java.util.ArrayList;
import java.util.List;

class Solution {
    /**
     * Returns the roots of all minimum height trees of an undirected tree.
     *
     * <p>Repeatedly strips the current layer of leaves; the 1 or 2 nodes that
     * remain are the roots giving minimum height.
     *
     * @param n     number of nodes, labelled 0..n-1
     * @param edges undirected edges, each {@code {u, v}}
     * @return the root labels (empty when {@code n == 0})
     */
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n == 0) {
            return new ArrayList<>();
        }
        if (n == 1) {
            List<Integer> ret = new ArrayList<>();
            ret.add(0);
            return ret;
        }
        List<Integer>[] graph = buildGraph(n, edges);

        // Initial leaves: nodes with exactly one neighbour.
        List<Integer> leaves = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (graph[i].size() == 1) {
                leaves.add(i);
            }
        }

        int remaining = n;
        // Keep removing leaves until there are only 1 or 2 nodes left
        while (remaining > 2) {
            remaining -= leaves.size();
            List<Integer> newLeaves = new ArrayList<>();
            for (int leaf : leaves) {
                for (int neighbor : graph[leaf]) {
                    // Integer.valueOf selects remove(Object); remove(int) would delete by index.
                    graph[neighbor].remove(Integer.valueOf(leaf));
                    if (graph[neighbor].size() == 1) {
                        newLeaves.add(neighbor);
                    }
                }
            }
            leaves = newLeaves;
        }
        return leaves;
    }

    /** Builds an undirected adjacency list for nodes 0..n-1. */
    @SuppressWarnings("unchecked") // generic array creation; every slot holds an ArrayList<Integer>
    private List<Integer>[] buildGraph(int n, int[][] edges) {
        List<Integer>[] graph = new ArrayList[n];
        for (int i = 0; i < n; i++) {
            graph[i] = new ArrayList<>();
        }
        for (int[] edge : edges) {
            graph[edge[0]].add(edge[1]);
            graph[edge[1]].add(edge[0]);
        }
        return graph;
    }
}
```

###### tags: `Week 5-Graphs`