Tree - Graph Valid Tree
All diagrams presented herein are original creations, meticulously designed to enhance comprehension and recall. Crafting these aids required considerable effort, and I kindly request attribution if this content is reused elsewhere.
Difficulty: Easy
DFS, Map
Problem
Given `n` nodes labeled from `0` to `n - 1` and a list of undirected edges (each edge is a pair of nodes), write a function to check whether these edges make up a valid tree. (A valid tree has no cycles, and all nodes must be connected.)
Example 1:
```
Input: n = 5, edges = [[0, 1], [0, 2], [0, 3], [1, 4]]
Output: true
```
Example 2:
```
Input: n = 5, edges = [[0, 1], [1, 2], [2, 3], [1, 3], [1, 4]]
Output: false
```
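A quick sanity check on the two examples uses a standard fact that the article does not state: a tree with `n` nodes has exactly `n - 1` edges. The helper name `edge_count_matches_tree` is hypothetical. This is only a necessary condition, not a full validity check.

```python
def edge_count_matches_tree(n, edges):
    # Necessary (not sufficient) condition: a tree on n nodes has n - 1 edges.
    return len(edges) == n - 1

example_1 = (5, [[0, 1], [0, 2], [0, 3], [1, 4]])
example_2 = (5, [[0, 1], [1, 2], [2, 3], [1, 3], [1, 4]])

print(edge_count_matches_tree(*example_1))  # True
print(edge_count_matches_tree(*example_2))  # False: 5 edges > 4, so a cycle must exist
```

Passing this check does not prove the graph is a tree. For example, `n = 4` with edges `[[0, 1], [1, 2], [2, 0]]` has 3 edges but contains a cycle and leaves node 3 disconnected. That is why the connectivity and cycle checks are still needed.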
Solution
There are two conditions we need to check to make sure the graph is a valid tree:
- All nodes are connected. This can be validated by checking `len(visited) == n` after the traversal.
- No cycle exists. This can be validated while running `dfs()`: if we reach a node that is already visited (other than by going back along the edge we just came from), a cycle is detected.
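Both conditions can also be checked without a traversal by using union-find (disjoint-set union). This is a different technique from the DFS this article uses. The sketch below is illustrative only, and the name `valid_tree_union_find` is hypothetical.

The sketch relies on two observations:
- An edge whose two endpoints already share a root closes a cycle.
- The graph is connected when exactly one component remains.

```python
def valid_tree_union_find(n, edges):
    # parent[i] links node i toward its set's root; every node starts alone.
    parent = list(range(n))

    def find(x):
        # Walk up to the root, halving the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    components = n
    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a == root_b:
            # Endpoints are already connected, so this edge closes a cycle.
            return False
        parent[root_a] = root_b  # merge the two sets
        components -= 1
    # Connected iff all nodes ended up in a single set.
    return components == 1

print(valid_tree_union_find(5, [[0, 1], [0, 2], [0, 3], [1, 4]]))          # True
print(valid_tree_union_find(5, [[0, 1], [1, 2], [2, 3], [1, 3], [1, 4]]))  # False
```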
High Level Explanation
- This is a straightforward DFS problem.
- Create the adjacency map.
- Visit each node and its neighbors.
- Keep track of the previous node: because the graph is undirected, every edge also leads straight back to the node we came from.
The input is `n` (the number of nodes) and `edges` (pairs of connected nodes). Let's start by building the adjacency map for the traversal.
```python
import collections

adjacency_map = collections.defaultdict(list)
for n1, n2 in edges:
    # Undirected edge: store it in both directions.
    adjacency_map[n1].append(n2)
    adjacency_map[n2].append(n1)
```
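To make the structure concrete, here is the adjacency map this loop builds for Example 1. This is a small illustration only: the edge list is hard-coded from the problem statement.

```python
import collections

# Example 1 from the problem statement.
edges = [[0, 1], [0, 2], [0, 3], [1, 4]]

adjacency_map = collections.defaultdict(list)
for n1, n2 in edges:
    # Undirected edge: record it in both directions.
    adjacency_map[n1].append(n2)
    adjacency_map[n2].append(n1)

print(dict(adjacency_map))
# {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0], 4: [1]}
```

Every edge appears twice, once under each endpoint. This is why the traversal later has to avoid walking straight back to `prev_node`.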
Initialize the `visited` set.
```python
visited = set()
```
Now create the `dfs()` function to traverse the graph. `dfs()` takes the current `node` and the previous node `prev_node`. Passing `prev_node` ensures we don't traverse backwards along the same edge, since the graph is undirected.

Check for a cycle first. As stated earlier, if the current node is already visited, a cycle is detected and we return `False`.
```python
def dfs(node, prev_node):
    # Reaching an already-visited node means we found a cycle.
    if node in visited:
        return False
```
Now, add the node to the `visited` set.
```python
    visited.add(node)
```
Traverse the neighbors of `node`, skipping `prev_node`. If a recursive `dfs()` call returns `False`, return `False` as well. Finally, return `True`.
```python
    for neighbor in adjacency_map[node]:
        # Skip the edge we just arrived on.
        if neighbor == prev_node:
            continue
        if not dfs(neighbor, node):
            return False
    return True
```
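To see why the `prev_node` check matters, consider a hypothetical variant of `dfs()` that omits it. The function name `dfs_without_prev_node` is made up for this illustration and is intentionally wrong.

```python
import collections

def dfs_without_prev_node(edges):
    # Intentionally buggy: it does NOT skip the node we came from.
    adjacency_map = collections.defaultdict(list)
    for node1, node2 in edges:
        adjacency_map[node1].append(node2)
        adjacency_map[node2].append(node1)
    visited = set()

    def dfs(node):
        if node in visited:
            return False
        visited.add(node)
        for neighbor in adjacency_map[node]:
            if not dfs(neighbor):
                return False
        return True

    return dfs(0)

# A single edge 0-1 is a valid tree, but this version walks 0 -> 1 -> 0
# and mistakes the return trip for a cycle.
print(dfs_without_prev_node([[0, 1]]))  # False
```

Because every undirected edge is stored in both directions, the return trip is always available. Skipping `prev_node` is what separates "the edge I came in on" from a genuine second path.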
Finally, call `dfs()` and also validate that all nodes have been visited. We pass `None` as `prev_node` on the first call so that the starting node visits all of its neighbors. Any value that is not a valid node label, such as `-1` or `n + 1`, would work as well.
```python
return dfs(0, None) and n == len(visited)
```
Here is the annotated version; some of the variable names are different, though.
Final Code
Here is the full code.
```python
import collections


def valid_tree(n, edges):
    # Build an undirected adjacency map.
    adjacency_map = collections.defaultdict(list)
    for node1, node2 in edges:
        adjacency_map[node1].append(node2)
        adjacency_map[node2].append(node1)

    visited = set()

    def dfs(node, prev_node):
        # Reaching an already-visited node means there is a cycle.
        if node in visited:
            return False
        visited.add(node)
        for neighbor in adjacency_map[node]:
            # Don't walk back along the edge we arrived on.
            if neighbor != prev_node:
                if not dfs(neighbor, node):
                    return False
        return True

    # No cycle reachable from node 0, and every node was reached.
    return dfs(0, None) and len(visited) == n
```
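One practical caveat: the recursive `dfs()` goes one Python stack frame deeper per node along a path. CPython's default recursion limit (see `sys.getrecursionlimit()`) is typically 1000, so a long path-shaped graph can raise `RecursionError`.

Below is a sketch of the same parent-tracking DFS with an explicit stack. It assumes the same `n`/`edges` input, and the name `valid_tree_iterative` is hypothetical. Unlike the recursive version, it marks a node as visited when it is pushed. A visited node reached again from anything other than its parent therefore still signals a second path, which means a cycle.

```python
import collections

def valid_tree_iterative(n, edges):
    adjacency_map = collections.defaultdict(list)
    for node1, node2 in edges:
        adjacency_map[node1].append(node2)
        adjacency_map[node2].append(node1)

    visited = {0}
    # Each stack entry is (node, prev_node), mirroring the recursive dfs() arguments.
    stack = [(0, None)]
    while stack:
        node, prev_node = stack.pop()
        for neighbor in adjacency_map[node]:
            if neighbor == prev_node:
                continue  # the edge we arrived on
            if neighbor in visited:
                # Already discovered via another path: cycle.
                return False
            visited.add(neighbor)
            stack.append((neighbor, node))
    # Every node must have been reached for the graph to be connected.
    return len(visited) == n

print(valid_tree_iterative(5, [[0, 1], [0, 2], [0, 3], [1, 4]]))          # True
print(valid_tree_iterative(5, [[0, 1], [1, 2], [2, 3], [1, 3], [1, 4]]))  # False

# A 5,000-node path: deep enough that the recursive version may hit the default limit.
path_edges = [[i, i + 1] for i in range(4999)]
print(valid_tree_iterative(5000, path_edges))  # True
```

Raising the limit with `sys.setrecursionlimit()` is another option, but an explicit stack avoids depending on interpreter stack size altogether.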