根据AC算法的原理([算法]AC算法原理),我们需要定义几个类

  • Keyword(关键字):构建前缀树用的关键字,关键字并不是一个单词,而是一个序列(或者说是一个列表),关键字可以是字符串(字符序列),可以是其他类型的列表,比如 [1,2,3,4] 。关键字里的每一个元素,作为状态转移的条件

  • TrieNode(状态节点):记录节点的跳转条件和目标节点引用、失败节点引用、节点数据等

  • Trie(前缀树):主要方法都在这

定义 Keyword,因为 Keyword 可以是字符串、Integer数组、其他对象数组等等,所以我们定义一个接口

public interface Keyword<T> {
    /**
     * 获取本关键字的单个字的列表
     */
    List<T> getWordSequence();
}

定义一个 String 类型的关键字,看下面的类定义可以知道,所谓的String类型关键字,其实就是char列表

public class StringKeyword implements Keyword<Character> {

    private final List<Character> wordSequence;


    public StringKeyword(@NonNull String str) {
        wordSequence = new ArrayList<>(str.length());
        for (int i = 0; i < str.length(); i++) {
            wordSequence.add(str.charAt(i));
        }
    }


    @Override
    public List<Character> getWordSequence() {
        return wordSequence;
    }
}

接下来是 TrieNode

@Data
public class TrieNode<K, T> {
    /**
     * 是否是结束节点
     */
    private boolean end;

    /**
     * 当前节点的数据,这是一个NodeData对象
     */
    private T nodeData;

    /**
     * 深度
     */
    private int deep = 0;

    /**
     * 根据字,成功转到的下一个状态节点
     */
    private Map<K, TrieNode<K, T>> next = new HashMap<>();

    /**
     * 失败时跳转的下一个状态
     */
    private TrieNode<K, T> failNode;

}

最后最重要的 Trie,关键的几个方法

  • ​addWord(Keyword<K> keyword, V data)​:构建关键字和节点数据的关系

  • ​makeAutomation()​:构建失败转移关系

  • ​match(Keyword<K> keyword)​:传入关键字,找结果

public class Trie<K, V> {

    /**
     * 是否已经创建了自动机
     */
    private boolean alreadyInit = false;

    /**
     * 根节点
     */
    private final TrieNode<K, V> root = new TrieNode<>();

    /**
     * 节点列表
     */
    private final List<TrieNode<K, V>> nodeList = new LinkedList<>();

    /**
     * 深度是1的节点列表
     */
    private final LinkedList<TrieNode<K, V>> oneDeepNodeList = new LinkedList<>();

    public Trie() {
        this.nodeList.add(this.root);
    }

    /**
     * 添加一个关键字以及它对应的数据
     *
     * @param keyword 关键字
     * @param data    关键字对应的数字
     */
    public void addWord(Keyword<K> keyword, V data) {
        if (this.alreadyInit) {
            throw new RuntimeException("Automaton created, no more can be added.");
        }

        List<K> wordSequence = keyword.getWordSequence();
        int len = wordSequence.size();
        TrieNode<K, V> pos = this.root;
        for (int i = 0; i < len; i++) {
            K word = wordSequence.get(i);
            TrieNode<K, V> node;
            if (!pos.getNext().containsKey(word)) {
                node = new TrieNode<>();
                node.setDeep(i + 1);
                this.nodeList.add(node);

                if (node.getDeep() == 1) {
                    node.setFailNode(this.root);
                    this.oneDeepNodeList.add(node);
                }

                pos.getNext().put(word, node);
            } else {
                node = pos.getNext().get(word);
            }

            if (i == len - 1) {
                node.setEnd(true);
                node.setNodeData(data);
            }

            pos = pos.getNext().get(word);
        }
    }

    /**
     * 检查关键字是否存在
     *
     * @param keyword 关键字
     * @return 存在则返回true,否则返回false
     */
    public boolean checkWord(Keyword<K> keyword) {
        int len = keyword.getWordSequence().size();
        TrieNode<K, V> pos = this.root;
        boolean ok = false;
        for (int i = 0; i < len; i++) {
            K word = keyword.getWordSequence().get(i);
            if (pos.getNext().containsKey(word)) {
                pos = pos.getNext().get(word);
            } else {
                break;
            }

            if (i == len - 1 && pos.isEnd()) {
                ok = true;
            }
        }

        return ok;
    }

    /**
     * 获取关键字的数据
     *
     * @param keyword 关键字
     * @return 关键字存在则返回对应节点的数据对象,否则返回null
     */
    public V getNodeData(Keyword<K> keyword) {
        int len = keyword.getWordSequence().size();
        TrieNode<K, V> pos = this.root;
        V data = null;
        for (int i = 0; i < len; i++) {
            K word = keyword.getWordSequence().get(i);
            if (pos.getNext().containsKey(word)) {
                pos = pos.getNext().get(word);
            } else {
                break;
            }

            if (i == len - 1 && pos.isEnd()) {
                data = pos.getNodeData();
            }
        }

        return data;
    }

    /**
     * 修改关键字对应的数据对象(NodeData)的数据
     *
     * @param keyword 关键字
     * @param data    数据
     * @return 关键字存在则返回true,否则返回false
     */
    public boolean setNodeData(Keyword<K> keyword, V data) {
        int len = keyword.getWordSequence().size();
        TrieNode<K, V> pos = this.root;
        boolean ok = false;
        for (int i = 0; i < len; i++) {
            K word = keyword.getWordSequence().get(i);
            if (pos.getNext().containsKey(word)) {
                pos = pos.getNext().get(word);
            } else {
                break;
            }

            if (i == len - 1 && pos.isEnd()) {
                pos.setNodeData(data);
                ok = true;
            }
        }

        return ok;
    }

    /**
     * 构建自动机,计算每个节点的失败跳转节点
     */
    public void makeAutomation() {
        if (alreadyInit) {
            return;
        }

        // 队列初始化为全部是深度为1的节点
        LinkedList<TrieNode<K, V>> nodeQueue = this.oneDeepNodeList;

        while (!nodeQueue.isEmpty()) {
            TrieNode<K, V> r = nodeQueue.pollFirst();

            r.getNext().forEach((word, s) -> {
                nodeQueue.addLast(s);

                // 寻找s的失败跳转节点
                TrieNode<K, V> failNode = r.getFailNode();

                // 如果这个节点输入word也没有可到达的节点,并且不是根节点,则继续找这个节点的失败跳转节点
                while (!failNode.getNext().containsKey(word) && failNode.getDeep() != 0) {
                    failNode = failNode.getFailNode();
                }

                s.setFailNode(failNode.getNext().getOrDefault(word, failNode));
            });
        }

        this.alreadyInit = true;
    }

    /**
     * 对关键字进行匹配
     *
     * @param keyword 关键字
     * @return 返回列表,元素是匹配到的关键字对应的开始位置以及关键字对应的数据
     */
    public List<MatchResult<V>> match(Keyword<K> keyword) {
        List<MatchResult<V>> result = new ArrayList<>();
        int len = keyword.getWordSequence().size();

        TrieNode<K, V> node = this.root;
        for (int i = 0; i < len; i++) {
            K word = keyword.getWordSequence().get(i);
            while (!node.getNext().containsKey(word) && node.getDeep() != 0) {
                node = node.getFailNode();
            }

            if (node.getNext().containsKey(word)) {
                node = node.getNext().get(word);
                if (node.isEnd()) {
                    result.add(MatchResult.<V>builder()
                            .startIndex(i + 1 - node.getDeep())
                            .data(node.getNodeData())
                            .build());
                }
            }
        }

        return result;
    }
}

最后最后,写个单元测试验证一下是否可行

public class AcTest {

    @Test
    public void test() {
        Trie<Character, String> trie = new Trie<>();
        trie.addWord(new StringKeyword("西红柿"), "西红柿data");
        trie.addWord(new StringKeyword("菠萝"), "菠萝data");
        trie.addWord(new StringKeyword("西红柿子"), "西红柿子data");
        trie.addWord(new StringKeyword("西瓜"), "西瓜data");
        trie.addWord(new StringKeyword("西瓜汁"), "西瓜汁data");

        trie.makeAutomation();

        List<MatchResult<String>> result = trie.match(new StringKeyword("我喜欢吃西红柿子和西瓜汁哈哈哈"));
        for (MatchResult<String> r : result) {
            System.out.println(r.getStartIndex() + ": " + r.getData());
        }
    }

}

输出

4: 西红柿data
4: 西红柿子data
9: 西瓜data
9: 西瓜汁data

表示"我喜欢吃西红柿子和西瓜汁哈哈哈"这个字符串中,命中关键字的偏移量位置以及命中了关键字对应的数据值

完整代码在这:https://gitee.com/kk_small_source/algorithm-learning/tree/master/ac/src/main/java/com/kk/learning/algorithm/ac