In a previous article, we covered the intuition behind the Aho-Corasick string matching algorithm. Now, I will explain its implementation in Python. There are a variety of ways to do this. I will explain the way which uses an adjacency list to store the trie. Arguably, it is much cleaner to use classes and objects, but I decided not to do that because that's way overdone, and I'm a hipster (if you want that version, you can email me for it – it's quite easy to change it to use classes).
Recall from last time that we needed to construct the trie, and then set its failure transitions. After the trie is constructed, we traverse the trie as we are reading in the input text and output the positions that we find the keywords at. Essentially, these three parts form the structure of the algorithm.
The trie is represented as an adjacency list. Each row of the adjacency list represents one node, and the index of the row is the unique id of that node. Each row is a dict of the form {'value':'', 'next_states':[],'fail_state':0,'output':[]}, where value is the character the node represents ('a', 'b', '#', '$', etc.), next_states is a list of the ids of the child nodes, fail_state is the id of the fail state, and output is a list of all the complete keywords that end at this node, including keywords that are proper suffixes of the string spelled by the path to the node (in this implementation, adding the same word to the trie more than once lists it more than once).
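To make the layout concrete, here is a minimal hand-built sketch of the rows for a trie that holds the single keyword 'ab'. The make_node helper and the rows name are hypothetical, introduced only for this illustration; they are not part of the implementation that follows.

```python
def make_node(value):
    # Hypothetical helper: builds one row of the adjacency list.
    return {'value': value, 'next_states': [], 'fail_state': 0, 'output': []}

rows = [make_node('')]            # row 0: the root
rows.append(make_node('a'))       # row 1
rows[0]['next_states'].append(1)  # root -> 'a'
rows.append(make_node('b'))       # row 2
rows[1]['next_states'].append(2)  # 'a' -> 'b'
rows[2]['output'].append('ab')    # the keyword 'ab' ends at row 2

for node_id, row in enumerate(rows):
    print(node_id, row)
```

The id of a node is nothing more than its position in the list, which is why next_states and fail_state can store plain integers.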
We initialize the trie, called AdjList, and add the root node. We then add the keywords to the trie one by one.
```python
from collections import deque

# The trie: one dict per node, indexed by node id.
AdjList = []

def init_trie(keywords):
    """ creates a trie of keywords, then sets fail transitions """
    create_empty_trie()
    add_keywords(keywords)
    set_fail_transitions()

def create_empty_trie():
    """ initialize the root of the trie """
    AdjList.append({'value':'', 'next_states':[],'fail_state':0,'output':[]})

def add_keywords(keywords):
    """ add all keywords in list of keywords """
    for keyword in keywords:
        add_keyword(keyword)
```
We also write a helper, find_next_state, which takes a node id and a value, and returns the id of the child of that node whose value matches value, or None if there is no such child.
```python
def find_next_state(current_state, value):
    # Linear scan over the children of current_state.
    for node in AdjList[current_state]["next_states"]:
        if AdjList[node]["value"] == value:
            return node
    return None
```
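As a quick illustration of the two possible results, the sketch below calls the helper on a small hand-written trie. The three rows are made up for this example (a root with children 'a' and 'b'), and the helper is repeated so that the snippet runs on its own.

```python
# Hand-written trie for illustration: root with children 'a' (id 1) and 'b' (id 2).
AdjList = [
    {'value': '',  'next_states': [1, 2], 'fail_state': 0, 'output': []},
    {'value': 'a', 'next_states': [],     'fail_state': 0, 'output': ['a']},
    {'value': 'b', 'next_states': [],     'fail_state': 0, 'output': ['b']},
]

def find_next_state(current_state, value):
    # Same helper as in the article, repeated so this snippet is self-contained.
    for node in AdjList[current_state]["next_states"]:
        if AdjList[node]["value"] == value:
            return node
    return None

print(find_next_state(0, 'b'))  # 2: the root has a child with value 'b'
print(find_next_state(0, 'z'))  # None: the root has no child with value 'z'
print(find_next_state(1, 'a'))  # None: node 1 has no children at all
```

Because the result can be None, callers compare it with `is None` before using it as an index.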
Note that, to keep my testing simple, keywords and input text are both converted to lowercase, so matching is case-insensitive. To add a keyword into the trie, we start from the root and traverse the longest prefix of the keyword that already exists in the trie, then we add the remaining characters of the keyword as a chain of new nodes.
```python
def add_keyword(keyword):
    """ add a keyword to the trie and mark output at the last node """
    keyword = keyword.lower()
    if not keyword:
        # Ignore empty keywords; indexing keyword[0] would raise an IndexError.
        return
    current_state = 0
    j = 0
    # Walk down the longest prefix of the keyword that is already in the trie.
    while j < len(keyword):
        child = find_next_state(current_state, keyword[j])
        if child is None:
            break
        current_state = child
        j = j + 1
    # Add the remaining characters as a chain of new nodes.
    for i in range(j, len(keyword)):
        node = {'value':keyword[i],'next_states':[],'fail_state':0,'output':[]}
        AdjList.append(node)
        AdjList[current_state]["next_states"].append(len(AdjList) - 1)
        current_state = len(AdjList) - 1
    AdjList[current_state]["output"].append(keyword)
```
The while loop finds the longest prefix of the keyword which exists in the trie so far, and exits when the character at index j can no longer be matched. The for loop goes through the rest of the keyword, creating a new node for each character and appending it to AdjList. len(AdjList) - 1 gives the id of the node we have just appended, since we are adding to the end of AdjList.
When we have finished adding the keyword to the trie, AdjList[current_state]["output"].append(keyword) appends the keyword to the output of the last node, to mark the end of the keyword at that node.
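The prefix sharing is easiest to see with two keywords that start the same way. The sketch below is a compact restatement of the functions in this article (not a replacement for them), repeated so it can run by itself, and the keywords 'cash' and 'cat' are chosen just for this example.

```python
AdjList = [{'value': '', 'next_states': [], 'fail_state': 0, 'output': []}]

def find_next_state(current_state, value):
    for node in AdjList[current_state]["next_states"]:
        if AdjList[node]["value"] == value:
            return node
    return None

def add_keyword(keyword):
    # Compact restatement of add_keyword, so the snippet is self-contained.
    current_state, j = 0, 0
    while j < len(keyword):
        child = find_next_state(current_state, keyword[j])
        if child is None:
            break
        current_state, j = child, j + 1
    for ch in keyword[j:]:
        AdjList.append({'value': ch, 'next_states': [], 'fail_state': 0, 'output': []})
        AdjList[current_state]["next_states"].append(len(AdjList) - 1)
        current_state = len(AdjList) - 1
    AdjList[current_state]["output"].append(keyword)

add_keyword('cash')   # creates nodes 1..4: c, a, s, h
add_keyword('cat')    # reuses c and a, creates only node 5: t

print([row['value'] for row in AdjList])  # ['', 'c', 'a', 's', 'h', 't']
print(AdjList[2]['next_states'])          # [3, 5]
print(AdjList[5]['output'])               # ['cat']
```

Only one new node is created for 'cat', because 'ca' was already in the trie when it was added.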
Now, to set the fail transitions. We will do a breadth-first search over the trie and set the failure state of each node. First, we set all the children of the root to have the root as their failure state, since the longest proper suffix of a single character is the empty string, which is represented by the root. The failure state of the root doesn't matter, since once we are at the root we simply leave the loop, but we can just set it to be the root itself.
Remember that the failure state of a node marks the end of the longest proper suffix of the currently matched string that is also in the trie, that is, a prefix of some keyword.
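That definition can be checked by brute force, without building a trie at all. The function below is a hypothetical helper written only for this illustration: it returns the string that the fail state should spell for a given matched string.

```python
def longest_proper_suffix_in_trie(matched, keywords):
    # Every path from the root spells a prefix of some keyword
    # (the empty string corresponds to the root itself).
    prefixes = {kw[:k] for kw in keywords for k in range(len(kw) + 1)}
    # Try the proper suffixes of matched from longest to shortest.
    for start in range(1, len(matched) + 1):
        if matched[start:] in prefixes:
            return matched[start:]
    return ''  # nothing matched: fall back to the root

keywords = ['cash', 'shew', 'ew']
print(repr(longest_proper_suffix_in_trie('cash', keywords)))  # 'sh'
print(repr(longest_proper_suffix_in_trie('shew', keywords)))  # 'ew'
print(repr(longest_proper_suffix_in_trie('ca', keywords)))    # ''
```

So the node for 'cash' should fail to the node for 'sh', and the node for 'ca' should fail to the root. The breadth-first search computes the same thing incrementally instead of checking every suffix.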
Consider the node r. We are setting the failure state for node child of r. Initially, state, the candidate parent of child's fail state, is r's fail state, which marks the longest proper suffix of the string ending at r that is in the trie. If there is no transition from state to a node with the same value as child, then we move to the next longest such suffix, which is the fail state of r's fail state, and so on, until we find one which works, or we are at the root.
We set child's fail state to be the node that this transition leads to, or the root if there is no such transition.
We append the output of the fail state to child's output because the fail state represents a suffix of the string which ends at child, so any keywords that end at the fail state also end at child. If we did not keep this line, we would miss keywords that are proper suffixes of the currently matched string.
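Put another way, once the fail transitions are set, the output list of a node should hold every keyword that is a suffix of the string spelled by the path to that node. A brute-force version of that statement, using a hypothetical helper that is not part of the implementation, looks like this:

```python
def keywords_ending_here(matched, keywords):
    # Brute-force view of a node's output list: every keyword
    # that is a suffix of the string matched so far.
    return [kw for kw in keywords if matched.endswith(kw)]

keywords = ['cash', 'shew', 'ew']
print(keywords_ending_here('shew', keywords))  # ['shew', 'ew']
print(keywords_ending_here('cash', keywords))  # ['cash']
print(keywords_ending_here('she', keywords))   # []
```

The 'ew' in the first result is the entry that would be lost without merging the fail state's output, since 'ew' is stored on a different branch of the trie than 'shew'.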
```python
def set_fail_transitions():
    q = deque()
    # Children of the root fail back to the root.
    for node in AdjList[0]["next_states"]:
        q.append(node)
        AdjList[node]["fail_state"] = 0
    # Breadth-first search: a node's fail state is shallower than the node,
    # so it has always been fully processed by the time we need it.
    while q:
        r = q.popleft()
        for child in AdjList[r]["next_states"]:
            q.append(child)
            state = AdjList[r]["fail_state"]
            while find_next_state(state, AdjList[child]["value"]) is None \
                    and state != 0:
                state = AdjList[state]["fail_state"]
            AdjList[child]["fail_state"] = find_next_state(state,
                                                           AdjList[child]["value"])
            if AdjList[child]["fail_state"] is None:
                AdjList[child]["fail_state"] = 0
            # Keywords that end at the fail state also end at this node.
            AdjList[child]["output"] = (AdjList[child]["output"] +
                AdjList[AdjList[child]["fail_state"]]["output"])
```
Finally, our trie is constructed. Given an input, line, we iterate through each character in line, following fail states whenever the current node has no transition on the next character. At each node, we check to see if there is any output, and we record every outputted word together with its index in line. (Here i is the index of the last character of a matched word j, so i - len(j) + 1 is the index of its first character.)
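The index arithmetic can be checked on its own with one match. The values below are picked by hand for the input "cashew" and the keyword "ew"; nothing here depends on the trie.

```python
line = "cashew"
word = "ew"
i = 5                      # index of the last character of the match ('w')
start = i - len(word) + 1  # index of the first character of the match

print(start)                      # 4
print(line[start:i + 1] == word)  # True
```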
```python
def get_keywords_found(line):
    """ returns a list of {"index", "word"} dicts, one per keyword match in line """
    line = line.lower()
    current_state = 0
    keywords_found = []

    for i in range(len(line)):
        # Follow fail states until some node has a transition on line[i].
        while find_next_state(current_state, line[i]) is None and current_state != 0:
            current_state = AdjList[current_state]["fail_state"]
        current_state = find_next_state(current_state, line[i])
        if current_state is None:
            current_state = 0
        else:
            # i is the index of the last character of each matched word.
            for j in AdjList[current_state]["output"]:
                keywords_found.append({"index":i-len(j) + 1,"word":j})
    return keywords_found
```
Yay! We are done!
Test it like so:
```python
init_trie(['cash', 'shew', 'ew'])
print(get_keywords_found("cashew"))
```
This should report cash at index 0, shew at index 2 and ew at index 4.
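If you want more confidence than a single example gives, one option is to compare against a brute-force search. The naive_keywords_found function below is a hypothetical reference written for this purpose; it is slow, but simple enough to trust.

```python
def naive_keywords_found(keywords, line):
    # Brute-force reference: try every keyword at every start position.
    line = line.lower()
    found = set()
    for word in keywords:
        word = word.lower()
        for start in range(len(line) - len(word) + 1):
            if line.startswith(word, start):
                found.add((start, word))
    return found

expected = naive_keywords_found(['cash', 'shew', 'ew'], "cashew")
print(sorted(expected))  # [(0, 'cash'), (2, 'shew'), (4, 'ew')]

# With the article's functions loaded and the trie initialized,
# the two results should agree:
# actual = {(m["index"], m["word"]) for m in get_keywords_found("cashew")}
# assert actual == expected
```

The comparison uses sets of (index, word) pairs, so it does not depend on the order in which matches are reported.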
As always, leave questions and concerns in the comments below. See you next time!