跳转至

单元测试

Python单元测试核心——从基础到精通

📋 本章目标

完成本章学习后,你将能够:

  1. 理解单元测试的概念和重要性
  2. 熟练使用 unittest 框架编写测试
  3. 掌握 pytest 框架的高级用法
  4. 学会编写高质量的测试代码
  5. 理解 Mock 和测试替身的概念
  6. 掌握测试覆盖率工具的使用

1. 单元测试基础

1.1 什么是单元测试

单元测试是对软件中最小的可测试单元进行检查和验证的过程。

Text Only
┌─────────────────────────────────────────────────────────────┐
│                      单元测试定义                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   测试对象: 函数、方法、类                                    │
│   测试范围: 单个功能点                                       │
│   测试人员: 开发人员                                         │
│   测试时机: 编码阶段                                         │
│   自动化: 高度自动化                                         │
│                                                             │
│   ┌─────────┐    ┌─────────┐    ┌─────────┐                │
│   │ 输入    │ ──▶│ 单元    │ ──▶│ 输出    │                │
│   │ (测试数据)│    │ (函数)  │    │ (预期)  │                │
│   └─────────┘    └─────────┘    └─────────┘                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 单元测试的好处

好处 说明
早期发现缺陷 在编码阶段发现问题,修复成本低
代码质量保证 强制编写可测试、模块化的代码
重构信心 有测试保护,放心重构代码
文档作用 测试代码展示了函数的用法
设计改进 难测试的代码往往设计有问题

1.3 FIRST原则

Text Only
优秀的单元测试应遵循FIRST原则:

F - Fast (快速)
    测试应该快速执行,秒级完成

I - Independent (独立)
    测试之间不应有依赖关系

R - Repeatable (可重复)
    在任何环境下都能得到相同结果

S - Self-Validating (自我验证)
    测试结果应该是明确的 Pass/Fail

T - Timely (及时)
    测试应该及时编写,最好在编码前

2. unittest框架

2.1 基本结构

Python
import unittest

class TestStringMethods(unittest.TestCase):
    """字符串方法测试示例"""

    def test_upper(self):
        """测试大写转换"""
        self.assertEqual('hello'.upper(), 'HELLO')

    def test_isupper(self):
        """测试是否大写"""
        self.assertTrue('HELLO'.isupper())
        self.assertFalse('Hello'.isupper())

    def test_split(self):
        """测试字符串分割"""
        s = 'hello world'
        self.assertEqual(s.split(), ['hello', 'world'])
        # 检查分割后列表长度
        with self.assertRaises(TypeError):
            s.split(2)

if __name__ == '__main__':
    unittest.main()

2.2 常用断言方法

Python
import unittest

class TestAssertions(unittest.TestCase):
    """断言方法示例"""

    def test_equality(self):
        """相等性断言"""
        # 相等
        self.assertEqual(1 + 1, 2)
        self.assertEqual("hello", "hello")

        # 不相等
        self.assertNotEqual(1 + 1, 3)

    def test_boolean(self):
        """布尔断言"""
        self.assertTrue(1 == 1)
        self.assertFalse(1 == 2)

    def test_comparison(self):
        """比较断言"""
        self.assertGreater(5, 3)      # 5 > 3
        self.assertLess(3, 5)         # 3 < 5
        self.assertGreaterEqual(5, 5) # 5 >= 5
        self.assertLessEqual(3, 5)    # 3 <= 5

    def test_membership(self):
        """成员断言"""
        self.assertIn(1, [1, 2, 3])
        self.assertNotIn(4, [1, 2, 3])

    def test_type(self):
        """类型断言"""
        self.assertIsInstance("hello", str)
        self.assertIsInstance([1, 2], list)

    def test_none(self):
        """None断言"""
        self.assertIsNone(None)
        self.assertIsNotNone("value")

    def test_exception(self):
        """异常断言"""
        with self.assertRaises(ValueError):
            int("not a number")

        with self.assertRaises(TypeError):
            "string" + 123

    def test_almost_equal(self):
        """浮点数近似相等"""
        self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)

2.3 测试固件 (Fixtures)

Python
import unittest
import tempfile
import os

class TestFileOperations(unittest.TestCase):
    """文件操作测试 - 展示测试固件"""

    @classmethod  # @classmethod类方法,第一个参数为类本身
    def setUpClass(cls):
        """类级别设置 - 整个测试类执行一次"""
        print("开始测试文件操作...")
        cls.test_dir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        """类级别清理"""
        print("测试完成!")
        import shutil
        shutil.rmtree(cls.test_dir)

    def setUp(self):
        """每个测试方法前执行"""
        self.test_file = os.path.join(self.test_dir, 'test.txt')
        with open(self.test_file, 'w') as f:
            f.write('test content')

    def tearDown(self):
        """每个测试方法后执行"""
        if os.path.exists(self.test_file):
            os.remove(self.test_file)

    def test_file_exists(self):
        """测试文件存在"""
        self.assertTrue(os.path.exists(self.test_file))

    def test_file_content(self):
        """测试文件内容"""
        with open(self.test_file, 'r') as f:
            content = f.read()
        self.assertEqual(content, 'test content')

2.4 测试套件

Python
import unittest

# 创建测试套件
def create_test_suite():
    """创建组合测试套件"""
    suite = unittest.TestSuite()

    # 添加单个测试用例
    suite.addTest(TestStringMethods('test_upper'))

    # 添加整个测试类
    suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestStringMethods))

    # 从模块加载
    # suite.addTests(unittest.TestLoader().loadTestsFromModule(my_module))

    return suite

# 运行套件
if __name__ == '__main__':
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(create_test_suite())

3. pytest框架

3.1 基本用法

Python
# pytest 不需要继承任何类,更简洁

def test_simple():
    """最简单的测试"""
    assert 1 + 1 == 2

def test_string_operations():
    """字符串操作测试"""
    assert "hello".upper() == "HELLO"
    assert "HELLO".isupper() is True
    assert "hello world".split() == ["hello", "world"]

def test_list_operations():
    """列表操作测试"""
    my_list = [1, 2, 3]
    my_list.append(4)
    assert len(my_list) == 4
    assert 4 in my_list

3.2 pytest断言

Python
import pytest

def test_equality():
    """相等性断言 - pytest直接使用assert"""
    assert 1 + 1 == 2
    assert "hello" == "hello"

def test_approximation():
    """近似值断言"""
    assert 0.1 + 0.2 == pytest.approx(0.3)

def test_exception():
    """异常断言"""
    with pytest.raises(ValueError):
        int("not a number")

    # 检查异常消息
    with pytest.raises(ValueError, match="invalid literal"):
        int("abc")

def test_warning():
    """警告断言"""
    import warnings
    with pytest.warns(UserWarning):
        warnings.warn("test warning", UserWarning)

def test_custom_message():
    """自定义断言失败消息"""
    value = 5
    assert value > 10, f"值 {value} 应该大于10"

3.3 pytest fixtures

Python
import pytest

# 简单fixture
@pytest.fixture
def sample_data():
    """返回测试数据"""
    return {"name": "test", "value": 42}

def test_with_fixture(sample_data):
    """使用fixture的测试"""
    assert sample_data["name"] == "test"
    assert sample_data["value"] == 42

# 带清理的fixture
@pytest.fixture
def temp_file():
    """创建临时文件并在测试后清理"""
    import tempfile
    import os

    # setup
    fd, path = tempfile.mkstemp()
    os.write(fd, b"test content")
    os.close(fd)

    yield path  # 返回给测试使用

    # teardown (yield之后的代码)
    os.unlink(path)

def test_temp_file(temp_file):
    """使用临时文件测试"""
    with open(temp_file, 'r') as f:
        content = f.read()
    assert content == "test content"

# fixture作用域
@pytest.fixture(scope="module")
def database_connection():
    """模块级别fixture - 整个模块只执行一次"""
    print("\n连接数据库...")
    conn = {"connected": True}
    yield conn
    print("\n关闭数据库连接...")
    conn["connected"] = False

def test_query_1(database_connection):
    assert database_connection["connected"] is True

def test_query_2(database_connection):
    assert database_connection["connected"] is True

3.4 参数化测试

Python
import pytest

# 单参数
@pytest.mark.parametrize("input,expected", [
    ("hello", "HELLO"),
    ("WORLD", "WORLD"),
    ("Python", "PYTHON"),
])
def test_upper_param(input, expected):
    """参数化测试 - 大写转换"""
    assert input.upper() == expected

# 多参数
@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (5, 5, 10),
    (-1, 1, 0),
    (100, -50, 50),
])
def test_add_param(a, b, expected):
    """参数化测试 - 加法"""
    assert a + b == expected

# 组合参数化
@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", [10, 20])
def test_combination(x, y):
    """组合参数化 - 会产生4个测试"""
    assert x + y in [11, 12, 21, 22]

# 使用ids命名
@pytest.mark.parametrize("input,expected", [
    ("hello", "HELLO"),
    ("world", "WORLD"),
], ids=["lowercase_hello", "lowercase_world"])
def test_with_ids(input, expected):
    assert input.upper() == expected

3.5 标记和跳过

Python
import sys
import pytest

# 跳过测试
@pytest.mark.skip(reason="功能尚未实现")
def test_future_feature():
    assert False

# 条件跳过
@pytest.mark.skipif(sys.version_info < (3, 8), reason="需要Python 3.8+")
def test_python38_feature():
    # 使用Python 3.8+的特性
    pass

# 标记预期失败
@pytest.mark.xfail(reason="已知bug,待修复")
def test_known_bug():
    assert 1 == 2  # 这个测试预期会失败

# 自定义标记
@pytest.mark.slow
def test_slow_operation():
    """慢速测试"""
    import time
    time.sleep(2)
    assert True

# 运行时选择
# pytest -m slow  # 只运行slow标记的测试
# pytest -m "not slow"  # 运行非slow的测试

4. Mock和测试替身

4.1 Mock基础

Python
from unittest.mock import Mock, MagicMock, patch

def test_mock_basic():
    """Mock基础用法"""
    # 创建Mock对象
    mock = Mock()

    # 设置返回值
    mock.method.return_value = 42
    assert mock.method() == 42

    # 验证调用
    mock.method.assert_called_once()

def test_mock_attributes():
    """Mock属性"""
    mock = Mock()
    mock.name = "test"
    mock.value = 100

    assert mock.name == "test"
    assert mock.value == 100

def test_mock_side_effect():
    """Mock副作用"""
    mock = Mock()

    # 使用函数作为副作用
    mock.side_effect = lambda x: x * 2  # lambda匿名函数:简洁的单行函数
    assert mock(5) == 10
    assert mock(3) == 6

    # 使用异常作为副作用
    mock.side_effect = ValueError("error")
    try:  # try/except捕获异常
        mock()
    except ValueError as e:
        assert str(e) == "error"

    # 使用列表作为副作用(顺序返回)
    mock.side_effect = [1, 2, 3]
    assert mock() == 1
    assert mock() == 2
    assert mock() == 3

4.2 patch装饰器

Python
from unittest.mock import patch, Mock
import requests

# 被测试的函数
def get_user_info(user_id):
    """获取用户信息"""
    response = requests.get(f"https://api.example.com/users/{user_id}")
    return response.json()

# 使用patch模拟requests
@patch('requests.get')
def test_get_user_info(mock_get):
    """测试获取用户信息"""
    # 设置模拟响应
    mock_response = Mock()
    mock_response.json.return_value = {"id": 1, "name": "Test User"}
    mock_get.return_value = mock_response

    # 调用被测函数
    result = get_user_info(1)

    # 验证结果
    assert result == {"id": 1, "name": "Test User"}
    mock_get.assert_called_once_with("https://api.example.com/users/1")

# 上下文管理器方式
def test_with_context_manager():
    """使用上下文管理器"""
    with patch('requests.get') as mock_get:
        mock_get.return_value.json.return_value = {"id": 2}
        result = get_user_info(2)
        assert result["id"] == 2

4.3 pytest-mock插件

Python
# pytest-mock 提供了mocker fixture

def test_with_mocker(mocker):
    """使用mocker fixture"""
    # Mock整个类
    mock_requests = mocker.patch('requests.get')
    mock_requests.return_value.json.return_value = {"status": "ok"}

    # Mock对象方法
    obj = mocker.MagicMock()
    obj.method.return_value = 42

    # Spy - 部分Mock
    # mocker.spy 实际调用会被执行,但可以验证调用情况

5. 测试覆盖率

5.1 使用coverage.py

Bash
# 安装
pip install coverage

# 运行测试并收集覆盖率
coverage run -m pytest

# 生成报告
coverage report

# 生成HTML报告
coverage html

# 查看报告
open htmlcov/index.html

5.2 pytest-cov插件

Bash
# 安装
pip install pytest-cov

# 运行测试并显示覆盖率
pytest --cov=myproject tests/

# 生成HTML报告
pytest --cov=myproject --cov-report=html tests/

5.3 配置覆盖率

INI
# .coveragerc 或 setup.cfg

[run]
source = myproject
omit =
    */tests/*
    */__pycache__/*
    */migrations/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    raise NotImplementedError
    if __name__ == .__main__.:
fail_under = 80

6. 高级测试技巧

6.1 测试私有方法

Python
class Calculator:
    """计算器类"""

    def _validate_input(self, value):
        """私有方法 - 验证输入"""
        if not isinstance(value, (int, float)):  # isinstance检查对象类型
            raise TypeError("必须是数字")
        return True

    def add(self, a, b):
        """公共方法"""
        self._validate_input(a)
        self._validate_input(b)
        return a + b

# 测试方式1: 通过公共方法间接测试
def test_add_validates_input():
    calc = Calculator()
    with pytest.raises(TypeError):
        calc.add("not a number", 1)

# 测试方式2: 直接访问私有方法(不推荐,但有时必要)
def test_private_method_directly():
    calc = Calculator()
    calc._validate_input(5)  # 可以访问
    with pytest.raises(TypeError):
        calc._validate_input("invalid")

6.2 测试异常和边界

Python
import pytest

def divide(a, b):
    """除法函数"""
    if b == 0:
        raise ValueError("除数不能为零")
    return a / b

class TestDivide:
    """除法测试类"""

    def test_normal(self):
        """正常情况"""
        assert divide(10, 2) == 5
        assert divide(9, 3) == 3

    def test_negative(self):
        """负数情况"""
        assert divide(-10, 2) == -5
        assert divide(10, -2) == -5

    def test_float(self):
        """浮点数情况"""
        assert divide(5, 2) == 2.5

    def test_zero_divisor(self):
        """除数为零"""
        with pytest.raises(ValueError, match="除数不能为零"):
            divide(10, 0)

    @pytest.mark.parametrize("a,b,expected", [
        (0, 1, 0),      # 被除数为零
        (1, 1, 1),      # 相同数
        (1e10, 1, 1e10), # 大数
        (1e-10, 1, 1e-10), # 小数
    ])
    def test_edge_cases(self, a, b, expected):
        """边界情况"""
        assert divide(a, b) == pytest.approx(expected)

6.3 测试文件系统

Python
import pytest
import tempfile
import os

@pytest.fixture
def temp_dir():
    """临时目录fixture"""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield tmpdir

def test_file_operations(temp_dir):
    """测试文件操作"""
    # 创建文件
    file_path = os.path.join(temp_dir, "test.txt")

    # 写入
    with open(file_path, 'w') as f:  # with自动管理资源,确保文件正确关闭
        f.write("test content")

    # 读取验证
    with open(file_path, 'r') as f:
        content = f.read()
    assert content == "test content"

    # 文件存在
    assert os.path.exists(file_path)

6.4 测试数据库操作

Python
import pytest
import sqlite3

@pytest.fixture
def in_memory_db():
    """内存数据库fixture"""
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()

    # 创建测试表
    cursor.execute('''
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            email TEXT UNIQUE
        )
    ''')
    conn.commit()

    yield conn  # yield生成器:惰性产出值,节省内存

    conn.close()

def test_insert_user(in_memory_db):
    """测试插入用户"""
    cursor = in_memory_db.cursor()
    cursor.execute(
        "INSERT INTO users (name, email) VALUES (?, ?)",
        ("Test User", "test@example.com")
    )
    in_memory_db.commit()

    cursor.execute("SELECT * FROM users")
    users = cursor.fetchall()
    assert len(users) == 1
    assert users[0][1] == "Test User"

def test_unique_email(in_memory_db):
    """测试邮箱唯一性"""
    cursor = in_memory_db.cursor()

    cursor.execute(
        "INSERT INTO users (name, email) VALUES (?, ?)",
        ("User1", "same@example.com")
    )
    in_memory_db.commit()

    # 插入相同邮箱应该失败
    with pytest.raises(sqlite3.IntegrityError):
        cursor.execute(
            "INSERT INTO users (name, email) VALUES (?, ?)",
            ("User2", "same@example.com")
        )

7. 面试题精选

Q1: unittest和pytest有什么区别?

参考答案

特性 unittest pytest
断言 self.assertEqual() assert
发现 需要继承TestCase 自动发现test_*.py
Fixtures setUp/tearDown @pytest.fixture
参数化 需要subTest @pytest.mark.parametrize
插件生态 较少 丰富(pytest-cov等)

Q2: 什么是测试替身?有哪些类型?

参考答案: 测试替身是用于替代真实依赖的对象:

  1. Dummy:只用于填充参数,不实际使用
  2. Stub:提供预定义的响应
  3. Spy:记录调用信息用于验证
  4. Mock:验证交互行为
  5. Fake:有实际工作实现(如内存数据库)

Q3: 如何测试异步代码?

参考答案

Python
import pytest

@pytest.mark.asyncio
async def test_async_function():  # async def定义异步函数;用await调用
    result = await async_operation()  # await等待异步操作完成
    assert result == expected_value

# 需要安装 pytest-asyncio


8. 最佳实践

8.1 测试命名规范

Python
# 好的命名
def test_add_should_return_sum_of_two_numbers():
    """清晰描述测试意图"""
    pass

def test_divide_should_raise_error_when_divisor_is_zero():
    """包含预期行为和条件"""
    pass

# 不好的命名
def test_add():  # 太笼统
    pass

def test1():     # 无意义
    pass

8.2 AAA模式

Python
def test_with_aaa_pattern():
    """Arrange-Act-Assert 模式"""

    # Arrange (准备)
    calculator = Calculator()
    a, b = 5, 3

    # Act (执行)
    result = calculator.add(a, b)

    # Assert (断言)
    assert result == 8  # assert断言:条件为False时抛出AssertionError

8.3 测试清单

  • 每个公共方法都有对应测试?
  • 测试覆盖正常和异常路径?
  • 测试之间相互独立?
  • 使用有意义的测试名称?
  • Mock了外部依赖?
  • 测试覆盖率达标(≥80%)?

9. 学习检查清单

完成本章学习后,请确认你能够:

  • 使用unittest编写基本测试
  • 使用pytest编写简洁的测试
  • 编写和使用fixtures
  • 使用参数化测试减少重复代码
  • 使用Mock模拟外部依赖
  • 测量和提高测试覆盖率
  • 遵循测试最佳实践

10. 参考资料

官方文档

相关教程


最后更新日期:2026-02-17 适用版本:测试与质量保证教程 v2026