Trie (Prefix Tree): A Complete Guide

Importance: ⭐⭐⭐⭐⭐ Difficulty: ⭐⭐⭐ Study time: 2-3 days Prerequisites: trees, strings, hash tables


📚 Contents

  1. Trie Basics
  2. Trie Implementation
  3. Trie Operations
  4. Trie Variants
  5. Applications
  6. LeetCode Problems Explained
  7. Practical Case Studies

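Trie Basics

[Figure: structure of a trie (dictionary tree)]

What Is a Trie?

A trie (pronounced "try"), also called a prefix tree or dictionary tree, is a tree-shaped data structure for storing and retrieving strings efficiently. Each edge is labeled with one character, so the path from the root to a node spells a prefix shared by every word stored beneath that node.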

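A Real-World Example

Picture an autocomplete system:

Text Only
You type "app" and the system suggests:
- apple
- application
- apply
- appointment
- appreciate

A trie reaches the node for the prefix "app" in O(m) time,
where m is the prefix length, no matter how many words the dictionary holds.
Listing the suggestions then costs time proportional to the size of the subtree under that node.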

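Why Learn Tries?

Efficient prefix matching:
- Insert: O(m), where m is the word length
- Search: O(m)
- Prefix match: O(m)

Space sharing:
- Words with a common prefix share the nodes for that prefix
- For example, "apple" and "application" share the path "appl"

Wide range of applications:
- Autocomplete / search suggestions
- Spell checking
- IP routing (longest prefix match)
- Word search games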

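Trie vs. Hash Table

| Operation | Trie | Hash table |
|---|---|---|
| Insert | O(m) | O(1) average (plus O(m) to hash the key) |
| Exact lookup | O(m) | O(1) average (plus O(m) to hash the key) |
| Prefix lookup | O(m) | O(n), every key must be scanned |
| Autocomplete | O(m + k) | O(n) |
| Space | Shared prefixes | Each key stored separately |

Here m is the key length, n is the number of stored keys, and k is the size of the subtree under the prefix node.

Conclusion: when you need prefix matching, a trie is the better choice.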
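
The prefix-lookup row of the comparison can be illustrated with a small sketch. It assumes a nested-dict trie in which the key "$" marks the end of a word; the helper names (`build_trie`, `trie_has_prefix`, `set_has_prefix`) are made up for this illustration and are not part of the implementation shown later.

```python
def build_trie(words):
    """Build a nested-dict trie; the key "$" marks the end of a word."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node["$"] = True
    return root


def trie_has_prefix(root, prefix):
    """Walk at most len(prefix) nodes, independent of vocabulary size."""
    node = root
    for ch in prefix:
        if ch not in node:
            return False
        node = node[ch]
    return True


def set_has_prefix(word_set, prefix):
    """A hash set has no prefix index, so stored words are inspected one by one."""
    return any(word.startswith(prefix) for word in word_set)


words = ["apple", "application", "apply", "banana", "band"]
trie = build_trie(words)
word_set = set(words)

print(trie_has_prefix(trie, "app"), set_has_prefix(word_set, "app"))  # True True
print(trie_has_prefix(trie, "cat"), set_has_prefix(word_set, "cat"))  # False False
```

Both functions return the same answers; the difference is that the trie walk touches at most m nodes, while the set scan may look at every stored word.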


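Trie Implementation

Node Structure

Text Only
TrieNode:
├── children: dict/array holding the child nodes
├── is_end: boolean marking the end of a word
└── count: optional, number of words passing through this node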

Visualization

Text Only
            root
           /    \
          a      b
          |      |
          p      a
          |      |
          p#     t#
          |      |
          l      t
          |      |
          e#     l
                 |
                 e#

# marks a node where a word ends
Stored words: apple, app, bat, battle
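
The same structure can be printed from code. This is a minimal sketch, assuming a nested-dict trie with "$" as a hypothetical end-of-word key; `build_trie` and `render` are illustrative helpers, not part of the classes below.

```python
def build_trie(words):
    """Build a nested-dict trie; the key "$" marks the end of a word."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node["$"] = True
    return root


def render(node, depth=0):
    """Return one indented line per node; ' #' flags a node where a word ends."""
    lines = []
    for ch in sorted(key for key in node if key != "$"):
        child = node[ch]
        marker = " #" if "$" in child else ""
        lines.append("  " * depth + ch + marker)
        lines.extend(render(child, depth + 1))
    return lines


lines = render(build_trie(["apple", "app", "bat", "battle"]))
print("\n".join(lines))
```

Each level of indentation corresponds to one level of depth in the diagram, and the marked nodes are the ends of "app", "apple", "bat", and "battle".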

Python Implementation

Python
class TrieNode:
    """A trie node."""
    def __init__(self):
        self.children = {}  # child nodes, keyed by character
        self.is_end = False  # True if a word ends at this node
        self.count = 0  # number of words passing through this node (prefix count)
        self.end_count = 0  # number of words ending at this node (exact count)

class Trie:
    """
    Trie implementation.
    Time complexity:
    - insert: O(m), where m is the word length
    - search: O(m)
    - prefix match: O(m)
    Space complexity: O(N * m) worst case, where N is the number of words and m their average length
    """
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        """Insert a word."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
            node.count += 1
        node.is_end = True
        node.end_count += 1

    def search(self, word: str) -> bool:
        """Exact word search."""
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end

    def startsWith(self, prefix: str) -> bool:
        """Prefix search."""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return False
            node = node.children[char]
        return True

    def countWordsEqualTo(self, word: str) -> int:
        """Count how many times word was inserted (duplicates are counted)."""
        node = self.root
        for char in word:
            if char not in node.children:
                return 0
            node = node.children[char]
        # Use end_count rather than count: count also includes longer words
        # that merely pass through this node
        return node.end_count

    def countWordsStartingWith(self, prefix: str) -> int:
        """Count the inserted words that start with prefix."""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return 0
            node = node.children[char]
        return node.count

    def erase(self, word: str) -> None:
        """Delete one occurrence of word (assumes word is present); nodes are kept."""
        node = self.root
        for char in word:
            node = node.children[char]
            node.count -= 1
        node.end_count -= 1
        if node.end_count == 0:
            node.is_end = False
    def getAllWords(self, node=None, prefix="", result=None):
        """Collect every word in the subtree under node (DFS); defaults to the whole trie."""
        # Check node and result separately: resetting node whenever result is None
        # would discard the start node passed in by getWordsWithPrefix
        if node is None:
            node = self.root
        if result is None:
            result = []

        if node.is_end:
            result.append(prefix)

        for char, child in node.children.items():
            self.getAllWords(child, prefix + char, result)

        return result

    def getWordsWithPrefix(self, prefix: str):
        """Collect all words that start with prefix."""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return []
            node = node.children[char]

        return self.getAllWords(node, prefix)

# Test
trie = Trie()
trie.insert("apple")
trie.insert("app")
trie.insert("application")
trie.insert("banana")
trie.insert("band")

print(trie.search("apple"))      # True
print(trie.search("app"))        # True
print(trie.search("appl"))       # False
print(trie.startsWith("app"))    # True
print(trie.startsWith("ban"))    # True
print(trie.getWordsWithPrefix("app"))  # ['app', 'apple', 'application']
print(trie.getAllWords())        # ['app', 'apple', 'application', 'banana', 'band']
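
The `erase` method above only adjusts counters, so nodes that no longer belong to any word stay in memory. One possible way to reclaim them is sketched below. It is a self-contained illustration with its own minimal `Node` type and hypothetical helpers (`insert`, `erase_and_prune`, `count_nodes`), not a drop-in replacement for the class above.

```python
class Node:
    """Minimal trie node used only for this sketch."""
    def __init__(self):
        self.children = {}
        self.is_end = False


def insert(root, word):
    node = root
    for ch in word:
        if ch not in node.children:
            node.children[ch] = Node()
        node = node.children[ch]
    node.is_end = True


def erase_and_prune(root, word):
    """Remove word and delete nodes no other word needs. Returns True if removed."""
    path = [root]
    for ch in word:
        nxt = path[-1].children.get(ch)
        if nxt is None:
            return False          # word is not in the trie
        path.append(nxt)
    if not path[-1].is_end:
        return False              # word is only a prefix of other words
    path[-1].is_end = False
    # Walk back up and drop each node that is now a non-terminal leaf.
    for depth in range(len(word), 0, -1):
        node = path[depth]
        if node.children or node.is_end:
            break                 # still needed by another word
        del path[depth - 1].children[word[depth - 1]]
    return True


def count_nodes(node):
    return 1 + sum(count_nodes(child) for child in node.children.values())


root = Node()
for w in ("app", "apple"):
    insert(root, w)
print(count_nodes(root))          # 6: root + a, p, p, l, e
erase_and_prune(root, "apple")
print(count_nodes(root))          # 4: root + a, p, p
```

The upward walk stops at the first node that still ends a word or still has children, which is why the shared path for "app" survives.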

C++ Implementation (Competitive-Programming Style)

C++
#include <bits/stdc++.h>  // GCC catch-all header, common in contests (not standard C++)
using namespace std;

/**
 * Trie node structure
 */
struct TrieNode {
    TrieNode* children[26];  // assumes lowercase letters 'a'-'z' only
    bool isEnd;              // true if a word ends at this node
    int count;               // number of words passing through this node

    TrieNode() {
        memset(children, 0, sizeof(children));
        isEnd = false;
        count = 0;
    }
};

/**
 * Trie implementation
 * Time complexity:
 * - insert: O(m), where m is the word length
 * - search: O(m)
 * - prefix match: O(m)
 */
class Trie {
private:
    TrieNode* root;

public:
    Trie() {
        root = new TrieNode();
    }

    // Insert a word
    void insert(string word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                node->children[idx] = new TrieNode();
            }
            node = node->children[idx];
            node->count++;
        }
        node->isEnd = true;
    }

    // Exact word search
    bool search(string word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                return false;
            }
            node = node->children[idx];
        }
        return node->isEnd;
    }

    // Prefix search
    bool startsWith(string prefix) {
        TrieNode* node = root;
        for (char c : prefix) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                return false;
            }
            node = node->children[idx];
        }
        return true;
    }

    // Count the words that start with prefix
    int countWordsStartingWith(string prefix) {
        TrieNode* node = root;
        for (char c : prefix) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                return 0;
            }
            node = node->children[idx];
        }
        return node->count;
    }

    // Collect all words (DFS)
    void getAllWordsHelper(TrieNode* node, string prefix, vector<string>& result) {
        if (node->isEnd) {
            result.push_back(prefix);
        }
        for (int i = 0; i < 26; i++) {
            if (node->children[i]) {
                char c = 'a' + i;
                getAllWordsHelper(node->children[i], prefix + c, result);
            }
        }
    }

    vector<string> getAllWords() {
        vector<string> result;
        getAllWordsHelper(root, "", result);
        return result;
    }

    // Collect all words that start with prefix
    vector<string> getWordsWithPrefix(string prefix) {
        TrieNode* node = root;
        for (char c : prefix) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                return {};
            }
            node = node->children[idx];
        }

        vector<string> result;
        getAllWordsHelper(node, prefix, result);
        return result;
    }
};

// Test
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    Trie trie;
    trie.insert("apple");
    trie.insert("app");
    trie.insert("application");
    trie.insert("banana");
    trie.insert("band");

    cout << boolalpha;
    cout << "search(apple): " << trie.search("apple") << endl;      // true
    cout << "search(app): " << trie.search("app") << endl;          // true
    cout << "search(appl): " << trie.search("appl") << endl;        // false
    cout << "startsWith(app): " << trie.startsWith("app") << endl;  // true
    cout << "startsWith(ban): " << trie.startsWith("ban") << endl;  // true

    cout << "\nWords with prefix 'app':" << endl;
    vector<string> words = trie.getWordsWithPrefix("app");
    for (const string& w : words) {
        cout << "  " << w << endl;
    }

    cout << "\nAll words:" << endl;
    words = trie.getAllWords();
    for (const string& w : words) {
        cout << "  " << w << endl;
    }

    return 0;
}

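Trie Operations

1. Insertion in Detail

Process

Text Only
Inserting "apple":

1. Start at the root
2. 'a': create a new node
3. 'p': create a new node
4. 'p': create a new node
5. 'l': create a new node
6. 'e': create a new node and set is_end=true

Inserting "app":

1. Start at the root
2. 'a': already exists, reuse it
3. 'p': already exists, reuse it
4. 'p': already exists, reuse it
5. Set is_end=true on that existing node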
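
The walkthrough above can be reproduced with a short trace. This is a sketch that assumes a nested-dict trie with "$" as a hypothetical end-of-word key; `traced_insert` is an illustrative helper, not a method of the Trie class.

```python
def traced_insert(root, word):
    """Insert word into a nested-dict trie, recording created vs. reused nodes."""
    steps = []
    node = root
    for ch in word:
        if ch in node:
            steps.append(f"'{ch}': reused")
        else:
            node[ch] = {}
            steps.append(f"'{ch}': created")
        node = node[ch]
    node["$"] = True  # hypothetical end-of-word marker
    return steps


root = {}
apple_steps = traced_insert(root, "apple")
app_steps = traced_insert(root, "app")
print(apple_steps)  # five "created" steps
print(app_steps)    # three "reused" steps
```

Inserting "app" after "apple" allocates nothing; it only sets the end-of-word marker on a node that already exists.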

2. Search in Detail

Python
def search_visualization(trie, word):
    """Trace a search step by step."""
    print(f"\nSearching for '{word}':")
    node = trie.root
    path = ["root"]

    for i, char in enumerate(word):
        if char not in node.children:
            print(f"  Character '{char}' not found!")
            return False
        node = node.children[char]
        path.append(char)
        print(f"  Step {i+1}: found '{char}' -> path: {' -> '.join(path)}")

    if node.is_end:
        print(f"  ✓ Found the complete word '{word}'")
        return True
    else:
        print(f"  ✗ '{word}' is only a prefix, not a complete word")
        return False

# Demo
trie = Trie()
trie.insert("apple")
trie.insert("app")

search_visualization(trie, "apple")
search_visualization(trie, "app")
search_visualization(trie, "appl")

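3. Prefix Matching

Python
def prefix_match_demo(trie, prefix):
    """Demonstrate prefix matching."""
    print(f"\nPrefix match for '{prefix}':")

    # Walk down to the node for the prefix
    node = trie.root
    for char in prefix:
        if char not in node.children:
            print(f"  Prefix '{prefix}' not found")
            return []
        node = node.children[char]

    print("  Found the prefix node, collecting words with DFS...")
    # Pass an explicit result list so the DFS starts at the prefix node
    words = trie.getAllWords(node, prefix, [])
    print(f"  Found {len(words)} words: {words}")
    return words

# Demo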
trie = Trie()
words = ["apple", "app", "application", "apply", "apt", "bat", "ball"]
for w in words:
    trie.insert(w)

prefix_match_demo(trie, "app")
prefix_match_demo(trie, "ba")

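Trie Variants

1. Compressed Trie

Optimization: merge every chain of single-child nodes into one edge

Text Only
Standard trie:          Compressed trie:

      root                    root
       |                       |
       a                     "appl"
       |                     /     \
       p                  "e"#   "ication"#
       |
       p
       |
       l
      / \
     e#  i
         |
         c
         |
        ...
         |
         n#

Stored words: apple, application

Python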
class CompressedTrieNode:
    """A compressed-trie node."""
    def __init__(self):
        self.children = {}  # child nodes, keyed by string fragment
        self.is_end = False

class CompressedTrie:
    """Compressed trie."""
    def __init__(self):
        self.root = CompressedTrieNode()

    def insert(self, word):
        """Insert a word."""
        node = self.root
        i = 0

        while i < len(word):
            # Look for a child edge that shares a prefix with word[i:]
            found = False
            for key in list(node.children.keys()):
                # Length of the longest common prefix of key and word[i:]
                j = 0
                while j < len(key) and i + j < len(word) and key[j] == word[i + j]:
                    j += 1

                if j > 0:  # shared prefix found
                    found = True
                    if j == len(key):  # the whole key matches: descend
                        node = node.children[key]
                        i += j
                    else:  # partial match: split the edge
                        old_child = node.children.pop(key)

                        # Split the key at the end of the common part
                        common = key[:j]
                        rest_key = key[j:]  # non-empty, because j < len(key)
                        rest_word = word[i+j:]

                        # New intermediate node for the common part
                        mid_node = CompressedTrieNode()
                        node.children[common] = mid_node

                        # Re-attach the old subtree below the intermediate node
                        mid_node.children[rest_key] = old_child

                        if rest_word:
                            new_node = CompressedTrieNode()
                            mid_node.children[rest_word] = new_node
                            node = new_node
                        else:
                            node = mid_node

                        i = len(word)
                    break

            if not found:
                # No edge shares a prefix: add a new branch for the rest of the word
                new_node = CompressedTrieNode()
                node.children[word[i:]] = new_node
                node = new_node
                break

        node.is_end = True
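
The class above only covers insertion. A lookup on a compressed trie could follow the sketch below; it is self-contained, builds the example trie by hand, and uses a hypothetical `RadixNode` type and `radix_search` function. It assumes every edge label is non-empty, which the insert logic above also maintains.

```python
class RadixNode:
    """Hypothetical node type: edge labels are string fragments."""
    def __init__(self, is_end=False):
        self.children = {}
        self.is_end = is_end


def radix_search(root, word):
    """Return True if word was stored as a complete word."""
    node, i = root, 0
    while i < len(word):
        for label, child in node.children.items():
            if word.startswith(label, i):   # the whole edge label must match
                node, i = child, i + len(label)
                break
        else:
            return False                    # no edge continues the word
    return node.is_end


# Hand-built compressed trie for "apple" and "application":
# root --"appl"--> fork, fork --"e"--> end, fork --"ication"--> end
root = RadixNode()
fork = RadixNode()
root.children["appl"] = fork
fork.children["e"] = RadixNode(is_end=True)
fork.children["ication"] = RadixNode(is_end=True)

print(radix_search(root, "apple"))        # True
print(radix_search(root, "application"))  # True
print(radix_search(root, "appl"))         # False
print(radix_search(root, "app"))          # False
```

A query that ends in the middle of an edge label, such as "app", fails because no complete label matches at that position.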

2. Weighted Trie (for Ranked Autocomplete)

Python
class WeightedTrieNode:
    """A weighted trie node."""
    def __init__(self):
        self.children = {}
        self.is_end = False
        self.weight = 0  # word weight (e.g. search frequency)

class WeightedTrie:
    """Weighted trie that supports autocomplete ranked by weight."""
    def __init__(self):
        self.root = WeightedTrieNode()

    def insert(self, word, weight=0):
        """Insert a word with a weight."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = WeightedTrieNode()
            node = node.children[char]
        node.is_end = True
        node.weight = max(node.weight, weight)  # keep the largest weight seen

    def autocomplete(self, prefix, k=5):
        """
        Autocomplete: return the k highest-weighted words.
        Time: O(m + s + n log n), where m is the prefix length, s is the number
        of nodes under the prefix, and n is the number of candidates
        (all candidates are sorted before the top k are taken)
        """
        # Walk down to the prefix node
        node = self.root
        for char in prefix:
            if char not in node.children:
                return []
            node = node.children[char]

        # Collect every candidate
        candidates = []
        self._collect(node, prefix, candidates)

        # Sort by weight and return the top k
        candidates.sort(key=lambda x: -x[1])  # descending
        return [word for word, weight in candidates[:k]]

    def _collect(self, node, prefix, result):
        """DFS that collects every word with its weight."""
        if node.is_end:
            result.append((prefix, node.weight))

        for char, child in node.children.items():
            self._collect(child, prefix + char, result)

# Test: search-engine autocomplete
search_trie = WeightedTrie()
search_trie.insert("apple", 1000)
search_trie.insert("application", 800)
search_trie.insert("apply", 600)
search_trie.insert("appointment", 400)
search_trie.insert("app", 1200)

print(search_trie.autocomplete("app", 3))
# ['app', 'apple', 'application'] (sorted by weight)
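
When only the top k suggestions are needed, sorting every candidate does more work than necessary. A possible alternative is sketched below using `heapq.nlargest` from the standard library, which keeps the selection step at O(n log k). The nested-dict trie, the "$" weight key, and the helper names are assumptions made for this sketch.

```python
import heapq


def build_weighted_trie(entries):
    """entries: iterable of (word, weight); the key "$" stores the weight at word ends."""
    root = {}
    for word, weight in entries:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node["$"] = max(node.get("$", 0), weight)
    return root


def iter_words(node, prefix):
    """Yield (word, weight) for every word in the subtree under node."""
    for key, child in node.items():
        if key == "$":
            yield prefix, child
        else:
            yield from iter_words(child, prefix + key)


def top_k(root, prefix, k=5):
    """Return the k highest-weighted words that start with prefix."""
    node = root
    for ch in prefix:
        if ch not in node:
            return []
        node = node[ch]
    best = heapq.nlargest(k, iter_words(node, prefix), key=lambda item: item[1])
    return [word for word, _ in best]


root = build_weighted_trie([
    ("apple", 1000), ("application", 800), ("apply", 600),
    ("appointment", 400), ("app", 1200),
])
print(top_k(root, "app", 3))  # ['app', 'apple', 'application']
```

The subtree still has to be traversed in full; only the ranking step gets cheaper. Caching the top k words on each node would avoid the traversal too, at the cost of more memory and slower updates.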

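Applications

1. Autocomplete System

Python
class AutocompleteSystem:
    """
    Autocomplete system.
    Typical uses: search engines, IDE code completion, input methods
    """
    def __init__(self):
        self.trie = WeightedTrie()
        self.base_weights = {}  # word -> base weight from the dictionary
        self.history = {}  # word -> number of times the user searched for it

    def add_word(self, word, base_weight=0):
        """Add a word to the dictionary."""
        self.base_weights[word] = base_weight
        # base weight + history weight
        weight = base_weight + self.history.get(word, 0)
        self.trie.insert(word, weight)

    def search(self, prefix, k=5):
        """Return search suggestions."""
        return self.trie.autocomplete(prefix, k)

    def record_search(self, word):
        """Record a user search and raise the word's weight."""
        self.history[word] = self.history.get(word, 0) + 1
        # WeightedTrie.insert keeps the maximum weight, so pass base + history;
        # the history count alone would never exceed the base weight
        weight = self.base_weights.get(word, 0) + self.history[word]
        self.trie.insert(word, weight)

# Usage example
auto = AutocompleteSystem()

# Initialize the dictionary
words = [
    ("python", 1000),
    ("python tutorial", 800),
    ("python download", 600),
    ("pytorch", 900),
    ("pycharm", 700),
]

for word, weight in words:
    auto.add_word(word, weight)

# User input
print("Suggestions for 'py':", auto.search("py", 3))
# User picks a result
auto.record_search("python")
print("Suggestions for 'py' again:", auto.search("py", 3))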

2. Spell Checking

Python
from difflib import SequenceMatcher

class SpellChecker:
    """
    Spell checker.
    Simplified: suggestions are dictionary words sharing the first two letters,
    ranked by difflib similarity ratio (not a true Levenshtein distance search)
    """
    def __init__(self, dictionary):
        self.trie = Trie()
        for word in dictionary:
            self.trie.insert(word)

    def check(self, word):
        """Check the spelling of a word."""
        if self.trie.search(word):
            return True, word

        # Look up similar words
        suggestions = self._find_similar(word, max_distance=2)
        return False, suggestions

    def _find_similar(self, target, max_distance=2):
        """Find similar words (max_distance is unused in this simplified version)."""
        # Simplified: only words sharing the first two letters are considered,
        # so a typo in those letters finds nothing; real systems need a stronger algorithm
        prefix = target[:2] if len(target) >= 2 else target
        candidates = self.trie.getWordsWithPrefix(prefix)

        # Rank by similarity ratio, most similar first
        def similarity(a, b):
            return SequenceMatcher(None, a, b).ratio()

        candidates.sort(key=lambda w: similarity(w, target), reverse=True)
        return candidates[:5]

# Test
dictionary = ["apple", "application", "apply", "banana", "band", "bat"]
checker = SpellChecker(dictionary)

is_correct, result = checker.check("aple")
print(f"'aple' spelled correctly: {is_correct}")
print(f"Suggestions: {result}")
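
The checker above narrows candidates by prefix, so it cannot recover from a typo in the first letters. A search bounded by real Levenshtein distance can be run directly on the trie by carrying one row of the dynamic-programming table down each path. The sketch below is one way to do it; the nested-dict trie, the "$" key, and the function names are assumptions for illustration, and it does not handle an empty string stored as a word.

```python
def build_trie(words):
    """Nested-dict trie; the key "$" stores the complete word at its last node."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node["$"] = word
    return root


def fuzzy_search(root, target, max_distance):
    """Return sorted (distance, word) pairs with Levenshtein distance <= max_distance."""
    results = []
    first_row = list(range(len(target) + 1))  # distances from "" to each target prefix

    def walk(node, ch, prev_row):
        # DP row for the trie path extended by ch
        row = [prev_row[0] + 1]
        for col in range(1, len(target) + 1):
            insert_cost = row[col - 1] + 1
            delete_cost = prev_row[col] + 1
            replace_cost = prev_row[col - 1] + (target[col - 1] != ch)
            row.append(min(insert_cost, delete_cost, replace_cost))
        if "$" in node and row[-1] <= max_distance:
            results.append((row[-1], node["$"]))
        if min(row) <= max_distance:  # otherwise no extension can get back under the limit
            for nxt, child in node.items():
                if nxt != "$":
                    walk(child, nxt, row)

    for ch, child in root.items():
        if ch != "$":
            walk(child, ch, first_row)
    return sorted(results)


trie = build_trie(["apple", "application", "apply", "banana", "band", "bat"])
print(fuzzy_search(trie, "aple", 1))  # [(1, 'apple')]
print(fuzzy_search(trie, "bad", 1))   # [(1, 'band'), (1, 'bat')]
```

Because words that share a prefix share the DP rows for that prefix, and whole subtrees are skipped once every entry in a row exceeds the limit, this usually inspects far fewer nodes than comparing the target against each dictionary word.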

3. IP Routing: Longest Prefix Match

Python
class IPRouter:
    """
    IP routing table (longest prefix match).
    Typical uses: routers, network devices
    """
    def __init__(self):
        self.trie = Trie()
        self.routes = {}  # binary prefix -> next hop

    def add_route(self, prefix, next_hop):
        """
        Add a route.
        prefix: e.g. "192.168.1.0/24"
        next_hop: next-hop address
        """
        # Convert the prefix to a string of bits
        binary_prefix = self._ip_to_binary(prefix)
        self.trie.insert(binary_prefix)
        self.routes[binary_prefix] = next_hop

    def route(self, ip):
        """
        Look up the next hop for an IP address.
        Returns the result of the longest prefix match.
        """
        binary_ip = self._ip_to_binary(ip, full=True)

        # Walk down from the root (shortest prefix first) and remember
        # the deepest node that ends a stored prefix
        node = self.trie.root
        longest_match = "" if node.is_end else None  # "" is the default route (/0)

        for i, bit in enumerate(binary_ip):
            if bit not in node.children:
                break
            node = node.children[bit]
            if node.is_end:
                longest_match = binary_ip[:i+1]

        if longest_match is not None:
            return self.routes.get(longest_match, "No route")
        return "No route"

    def _ip_to_binary(self, ip_with_prefix, full=False):
        """Convert a dotted IPv4 address (optionally with /len) to a bit string."""
        if '/' in ip_with_prefix:
            ip, prefix_len = ip_with_prefix.split('/')
            prefix_len = int(prefix_len)
        else:
            ip = ip_with_prefix
            prefix_len = 32  # a bare address is treated as a /32 host route

        binary = ''.join(format(int(part), '08b') for part in ip.split('.'))
        return binary if full else binary[:prefix_len]

# Test
router = IPRouter()
router.add_route("192.168.0.0/16", "Router A")
router.add_route("192.168.1.0/24", "Router B")
router.add_route("10.0.0.0/8", "Router C")

print(router.route("192.168.1.100"))  # Router B (longest prefix match)
print(router.route("192.168.2.100"))  # Router A
print(router.route("10.0.1.1"))       # Router C
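
The manual string parsing above can also be delegated to Python's standard `ipaddress` module, which validates the input as well. The following is a minimal IPv4-only sketch; the class name `BinaryRouteTrie`, the "hop" key, and the route names are assumptions for illustration, not a production routing table.

```python
import ipaddress


class BinaryRouteTrie:
    """Hypothetical binary trie keyed on IPv4 address bits."""
    def __init__(self):
        self.root = {}

    def add_route(self, cidr, next_hop):
        net = ipaddress.ip_network(cidr)  # raises ValueError if host bits are set
        bits = format(int(net.network_address), "032b")[:net.prefixlen]
        node = self.root
        for bit in bits:
            node = node.setdefault(bit, {})
        node["hop"] = next_hop            # "hop" cannot collide with the keys "0"/"1"

    def lookup(self, address):
        bits = format(int(ipaddress.ip_address(address)), "032b")
        node = self.root
        best = node.get("hop")            # default route (/0), if one was added
        for bit in bits:
            node = node.get(bit)
            if node is None:
                break
            best = node.get("hop", best)  # a deeper match overrides a shallower one
        return best


table = BinaryRouteTrie()
table.add_route("0.0.0.0/0", "Default gateway")
table.add_route("192.168.0.0/16", "Router A")
table.add_route("192.168.1.0/24", "Router B")
table.add_route("10.0.0.0/8", "Router C")

print(table.lookup("192.168.1.100"))  # Router B
print(table.lookup("192.168.2.100"))  # Router A
print(table.lookup("8.8.8.8"))        # Default gateway
```

Storing the next hop on the trie node itself removes the need for a separate prefix-to-hop dictionary, and a /0 entry at the root acts as the default route.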

LeetCode Problems Explained

Problem 1: Implement Trie (Prefix Tree)

Problem link: LeetCode 208

Python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False

class Trie:
    """
    Trie implementation.
    Time: insert O(m), search O(m), prefix match O(m)
    Space: O(N * m)
    """
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end = True

    def search(self, word: str) -> bool:
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end

    def startsWith(self, prefix: str) -> bool:
        node = self.root
        for char in prefix:
            if char not in node.children:
                return False
            node = node.children[char]
        return True

# C++ version
"""
#include <bits/stdc++.h>
using namespace std;

struct TrieNode {
    TrieNode* children[26];
    bool isEnd;

    TrieNode() {
        memset(children, 0, sizeof(children));
        isEnd = false;
    }
};

class Trie {
private:
    TrieNode* root;

public:
    Trie() {
        root = new TrieNode();
    }

    void insert(string word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->children[idx]) {
                node->children[idx] = new TrieNode();
            }
            node = node->children[idx];
        }
        node->isEnd = true;
    }

    bool search(string word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->children[idx]) return false;
            node = node->children[idx];
        }
        return node->isEnd;
    }

    bool startsWith(string prefix) {
        TrieNode* node = root;
        for (char c : prefix) {
            int idx = c - 'a';
            if (!node->children[idx]) return false;
            node = node->children[idx];
        }
        return true;
    }
};
"""

Problem 2: Word Search II

Problem link: LeetCode 212

Python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.word = None  # complete word, stored on the node where it ends

class Solution:
    """
    Word Search II
    Time: O(M * N * 4^L), where the board is M x N and L is the maximum word length
    Space: O(K * L), where K is the number of words
    """
    def findWords(self, board, words):
        # Build the trie
        root = TrieNode()
        for word in words:
            node = root
            for char in word:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.word = word

        result = []
        m, n = len(board), len(board[0])

        def dfs(i, j, node):
            """DFS from cell (i, j); node is the trie position reached so far."""
            char = board[i][j]
            if char not in node.children:
                return

            next_node = node.children[char]

            # A complete word ends here
            if next_node.word:
                result.append(next_node.word)
                next_node.word = None  # avoid reporting the same word twice

            # Mark the cell as visited
            board[i][j] = '#'

            # Explore the four neighbours
            for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                ni, nj = i + di, j + dj
                if 0 <= ni < m and 0 <= nj < n and board[ni][nj] != '#':
                    dfs(ni, nj, next_node)

            # Restore the cell
            board[i][j] = char

            # Pruning: remove next_node once it has no children left
            if not next_node.children:
                node.children.pop(char)

        # Start a search from every cell
        for i in range(m):
            for j in range(n):
                dfs(i, j, root)

        return result

# C++ version
"""
#include <bits/stdc++.h>
using namespace std;

struct TrieNode {
    TrieNode* children[26];
    string word;

    TrieNode() {
        memset(children, 0, sizeof(children));
    }
};

class Solution {
public:
    vector<string> findWords(vector<vector<char>>& board, vector<string>& words) {
        // Build the trie
        TrieNode* root = new TrieNode();
        for (string& word : words) {
            TrieNode* node = root;
            for (char c : word) {
                int idx = c - 'a';
                if (!node->children[idx]) {
                    node->children[idx] = new TrieNode();
                }
                node = node->children[idx];
            }
            node->word = word;
        }

        vector<string> result;
        int m = board.size(), n = board[0].size();

        function<void(int, int, TrieNode*)> dfs = [&](int i, int j, TrieNode* node) {
            char c = board[i][j];
            int idx = c - 'a';
            if (!node->children[idx]) return;

            TrieNode* nextNode = node->children[idx];

            if (!nextNode->word.empty()) {
                result.push_back(nextNode->word);
                nextNode->word.clear();
            }

            board[i][j] = '#';

            int dirs[4][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
            for (auto& dir : dirs) {
                int ni = i + dir[0], nj = j + dir[1];
                if (ni >= 0 && ni < m && nj >= 0 && nj < n && board[ni][nj] != '#') {
                    dfs(ni, nj, nextNode);
                }
            }

            board[i][j] = c;

            // Pruning: unlink the child once it has no children left
            bool hasChild = false;
            for (int k = 0; k < 26; k++) {
                if (nextNode->children[k]) {
                    hasChild = true;
                    break;
                }
            }
            if (!hasChild) {
                node->children[idx] = nullptr;
            }
        };

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                dfs(i, j, root);
            }
        }

        return result;
    }
};
"""

Problem 3: Add and Search Words

Problem link: LeetCode 211

Python
class WordDictionary:
    """
    Add and search words, with support for the wildcard '.'
    Time: add O(m), search O(26^m) in the worst case
    Space: O(N * m)
    """
    def __init__(self):
        self.root = {}

    def addWord(self, word: str) -> None:
        node = self.root
        for char in word:
            if char not in node:
                node[char] = {}
            node = node[char]
        node['#'] = True  # marks the end of a word

    def search(self, word: str) -> bool:
        def dfs(node, i):
            if i == len(word):
                return '#' in node

            char = word[i]
            if char != '.':
                if char not in node:
                    return False
                return dfs(node[char], i + 1)
            else:
                # Wildcard: try every child, skipping the end-of-word marker
                for key, child in node.items():
                    if key != '#' and dfs(child, i + 1):
                        return True
                return False

        return dfs(self.root, 0)

# C++ version
"""
#include <bits/stdc++.h>
using namespace std;

class WordDictionary {
private:
    unordered_map<char, WordDictionary*> children;
    bool isEnd = false;

public:
    WordDictionary() {}

    void addWord(string word) {
        WordDictionary* node = this;
        for (char c : word) {
            if (!node->children.count(c)) {
                node->children[c] = new WordDictionary();
            }
            node = node->children[c];
        }
        node->isEnd = true;
    }

    bool search(string word) {
        return search(word, 0);
    }

    bool search(string& word, int i) {
        if (i == word.size()) {
            return isEnd;
        }

        char c = word[i];
        if (c != '.') {
            if (!children.count(c)) return false;
            return children[c]->search(word, i + 1);
        } else {
            for (auto& [ch, child] : children) {
                if (child->search(word, i + 1)) {
                    return true;
                }
            }
            return false;
        }
    }
};
"""

Problem 4: Map Sum Pairs

Problem link: LeetCode 677

Python
class MapSum:
    """
    Map Sum Pairs
    Time: insert O(m), sum O(m)
    Space: O(N * m)
    """
    def __init__(self):
        self.root = {}
        self.key_values = {}  # key-value pairs inserted so far

    def insert(self, key: str, val: int) -> None:
        # Compute the difference (non-zero old value means this is an update)
        delta = val - self.key_values.get(key, 0)
        self.key_values[key] = val

        # Add the difference to the running sum of every prefix of key
        node = self.root
        for char in key:
            if char not in node:
                node[char] = {}
            node = node[char]
            # Sum of the values of all keys that pass through this node
            node['sum'] = node.get('sum', 0) + delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for char in prefix:
            if char not in node:
                return 0
            node = node[char]
        return node.get('sum', 0)

# C++ version
"""
#include <bits/stdc++.h>
using namespace std;

class MapSum {
private:
    unordered_map<char, MapSum*> children;
    int sum_val = 0;
    unordered_map<string, int> key_values;

public:
    MapSum() {}

    void insert(string key, int val) {
        int delta = val - key_values[key];
        key_values[key] = val;

        MapSum* node = this;
        for (char c : key) {
            if (!node->children.count(c)) {
                node->children[c] = new MapSum();
            }
            node = node->children[c];
            node->sum_val += delta;
        }
    }

    int sum(string prefix) {
        MapSum* node = this;
        for (char c : prefix) {
            if (!node->children.count(c)) {
                return 0;
            }
            node = node->children[c];
        }
        return node->sum_val;
    }
};
"""

Problem 5: Search Suggestions System

Problem link: LeetCode 1268

Python
class Solution:
    """
    Search Suggestions System
    Time: O(N * L * log N) to sort, O(N * L) to build the trie, O(M) to walk the search word;
          N = number of products, L = average length, M = length of searchWord
    Space: O(N * L)
    """
    def suggestedProducts(self, products, searchWord):
        # Sort the products so each node receives them in lexicographic order
        products.sort()

        # Build the trie
        root = {}
        for product in products:
            node = root
            for char in product:
                if char not in node:
                    node[char] = {}
                node = node[char]
                # Keep at most 3 products per node
                if 'products' not in node:
                    node['products'] = []
                if len(node['products']) < 3:
                    node['products'].append(product)

        # Search
        result = []
        node = root

        for char in searchWord:
            if node and char in node:
                node = node[char]
                result.append(node.get('products', []))
            else:
                node = None
                result.append([])

        return result

# C++ version
"""
#include <bits/stdc++.h>
using namespace std;

class Solution {
public:
    vector<vector<string>> suggestedProducts(vector<string>& products, string searchWord) {
        sort(products.begin(), products.end());

        // Build the trie
        struct TrieNode {
            TrieNode* children[26] = {};
            vector<string> products;
        };

        TrieNode* root = new TrieNode();
        for (string& product : products) {
            TrieNode* node = root;
            for (char c : product) {
                int idx = c - 'a';
                if (!node->children[idx]) {
                    node->children[idx] = new TrieNode();
                }
                node = node->children[idx];
                if (node->products.size() < 3) {
                    node->products.push_back(product);
                }
            }
        }

        // Search
        vector<vector<string>> result;
        TrieNode* node = root;

        for (char c : searchWord) {
            int idx = c - 'a';
            if (node && node->children[idx]) {
                node = node->children[idx];
                result.push_back(node->products);
            } else {
                node = nullptr;
                result.push_back({});
            }
        }

        return result;
    }
};
"""
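
For comparison, the same problem is often solved without a trie: once the products are sorted, all words sharing a prefix form one contiguous run, so binary search can find where that run starts. The sketch below uses `bisect.bisect_left` from the standard library; the function name `suggested_products` is an assumption for this illustration.

```python
import bisect


def suggested_products(products, search_word):
    """For each prefix of search_word, return up to 3 matching products in sorted order."""
    products = sorted(products)
    result = []
    prefix = ""
    start = 0
    for ch in search_word:
        prefix += ch
        # Longer prefixes never sort earlier, so the search can resume from start
        start = bisect.bisect_left(products, prefix, start)
        matches = [p for p in products[start:start + 3] if p.startswith(prefix)]
        result.append(matches)
    return result


suggestions = suggested_products(
    ["mobile", "mouse", "moneypot", "monitor", "mousepad"], "mouse"
)
for row in suggestions:
    print(row)
```

The trie version pays for its construction once and then answers each prefix with a single node step; the binary-search version needs no extra structure beyond the sorted list.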

Practical Case Studies

Application 1: Search-Engine Autocomplete

Python
class SearchEngine:
    """
    Search-engine autocomplete system.
    Includes weight calculation and personalized ranking
    """
    def __init__(self):
        self.trie = WeightedTrie()
        self.user_history = {}  # per-user search history
        self.global_freq = {}   # global search frequency

    def index_document(self, documents):
        """Index documents and extract keywords."""
        import re
        from collections import Counter  # Counter tallies how often each element occurs

        for doc in documents:
            words = re.findall(r'\b[a-z]+\b', doc.lower())
            word_freq = Counter(words)

            for word, freq in word_freq.items():
                self.global_freq[word] = self.global_freq.get(word, 0) + freq
                self.trie.insert(word, self.global_freq[word])

    def search(self, query, user_id=None, k=5):
        """
        Search suggestions.
        - query: what the user typed
        - user_id: user ID (for personalization)
        - k: number of suggestions to return
        """
        # Baseline suggestions
        suggestions = self.trie.autocomplete(query, k * 2)

        # Personalized ranking
        if user_id and user_id in self.user_history:
            user_pref = self.user_history[user_id]
            # Re-rank using the user's history
            suggestions.sort(
                key=lambda w: (user_pref.get(w, 0), self.global_freq.get(w, 0)),  # user's own clicks first, then global frequency
                reverse=True
            )

        return suggestions[:k]  # keep the top k

    def record_click(self, query, user_id=None):
        """Record a user click and update the weights."""
        self.global_freq[query] = self.global_freq.get(query, 0) + 1

        if user_id:
            if user_id not in self.user_history:
                self.user_history[user_id] = {}
            self.user_history[user_id][query] = \
                self.user_history[user_id].get(query, 0) + 1

        # Update the weight stored in the trie
        self.trie.insert(query, self.global_freq[query])

# Usage example
engine = SearchEngine()

# Index documents
docs = [
    "python tutorial for beginners",
    "python programming language",
    "python data science",
    "python machine learning",
    "javascript tutorial",
    "java programming"
]
engine.index_document(docs)

# Search suggestions
print("Suggestions for 'py':", engine.search("py"))
print("Suggestions for 'java':", engine.search("java"))

Application 2: DNA Sequence Analysis

Python
class DNASequenceAnalyzer:
    """
    DNA sequence analysis.
    Typical uses: bioinformatics, gene sequence matching
    """
    def __init__(self):
        self.trie = Trie()
        self.sequences = []

    def add_sequence(self, sequence):
        """Add a DNA sequence."""
        self.trie.insert(sequence)
        self.sequences.append(sequence)

    def find_common_prefix(self):
        """Find the longest common prefix of all sequences."""
        if not self.sequences:
            return ""

        # Follow the trie while there is exactly one way to continue
        # and no sequence has ended yet
        node = self.trie.root
        prefix = ""

        while len(node.children) == 1 and not node.is_end:
            char = list(node.children.keys())[0]
            prefix += char
            node = node.children[char]

        return prefix

    def find_similar(self, target, max_mismatch=1):
        """
        Find stored sequences of the same length as target that differ from it
        in at most max_mismatch positions (substitutions only)
        """
        results = []

        def dfs(node, prefix, i, mismatches):
            if mismatches > max_mismatch:
                return

            if i == len(target):
                if node.is_end:
                    results.append((prefix, mismatches))
                return

            for char, child in node.children.items():
                new_mismatches = mismatches + (char != target[i])
                dfs(child, prefix + char, i + 1, new_mismatches)

        dfs(self.trie.root, "", 0, 0)
        return results

# Test
dna = DNASequenceAnalyzer()
dna.add_sequence("ATCG")
dna.add_sequence("ATGC")
dna.add_sequence("ATCGG")

print(f"Longest common prefix: {dna.find_common_prefix()}")
print(f"Sequences similar to ATGG: {dna.find_similar('ATGG', max_mismatch=1)}")

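📝 Summary

Key Points

Trie essentials:
- Nodes represent characters, and paths spell out words
- Shared prefixes are stored once
- Insert, search, and prefix match are all O(m)

Complexity:
- Insert: O(m), where m is the word length
- Search: O(m)
- Prefix match: O(m)
- Space: O(N * m), where N is the number of words

Applications:
- Autocomplete / search suggestions
- Spell checking
- IP routing (longest prefix match)
- Word search games
- DNA sequence analysis

Variants:
- Compressed trie: merges single-child chains
- Weighted trie: supports ranking
- Suffix tree: string processing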

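Trie vs. Other Data Structures

| Scenario | Recommended structure | Reason |
|---|---|---|
| Prefix matching | Trie | O(m) time to reach the prefix |
| Exact lookup | Hash table | O(1) average time |
| Range queries | Balanced tree | Supports ordered traversal |
| Fuzzy matching | Edit-distance search | Supports spelling correction |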

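Next Steps

Keep learning:
- Segment trees / Fenwick trees (binary indexed trees): efficient structures for range queries
- String algorithms: KMP and other advanced string matching


📚 Further Reading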