单元测试¶
Python单元测试核心——从基础到精通
📋 本章目标¶
完成本章学习后,你将能够:
- 理解单元测试的概念和重要性
- 熟练使用 unittest 框架编写测试
- 掌握 pytest 框架的高级用法
- 学会编写高质量的测试代码
- 理解 Mock 和测试替身的概念
- 掌握测试覆盖率工具的使用
1. 单元测试基础¶
1.1 什么是单元测试¶
单元测试是对软件中最小的可测试单元进行检查和验证的过程。
Text Only
┌─────────────────────────────────────────────────────────────┐
│ 单元测试定义 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 测试对象: 函数、方法、类 │
│ 测试范围: 单个功能点 │
│ 测试人员: 开发人员 │
│ 测试时机: 编码阶段 │
│ 自动化: 高度自动化 │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 输入 │ ──▶│ 单元 │ ──▶│ 输出 │ │
│ │ (测试数据)│ │ (函数) │ │ (预期) │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
1.2 单元测试的好处¶
| 好处 | 说明 |
|---|---|
| 早期发现缺陷 | 在编码阶段发现问题,修复成本低 |
| 代码质量保证 | 强制编写可测试、模块化的代码 |
| 重构信心 | 有测试保护,放心重构代码 |
| 文档作用 | 测试代码展示了函数的用法 |
| 设计改进 | 难测试的代码往往设计有问题 |
1.3 FIRST原则¶
Text Only
优秀的单元测试应遵循FIRST原则:
F - Fast (快速)
测试应该快速执行,秒级完成
I - Independent (独立)
测试之间不应有依赖关系
R - Repeatable (可重复)
在任何环境下都能得到相同结果
S - Self-Validating (自我验证)
测试结果应该是明确的 Pass/Fail
T - Timely (及时)
测试应该及时编写,最好在编码前
2. unittest框架¶
2.1 基本结构¶
Python
import unittest
class TestStringMethods(unittest.TestCase):
"""字符串方法测试示例"""
def test_upper(self):
"""测试大写转换"""
self.assertEqual('hello'.upper(), 'HELLO')
def test_isupper(self):
"""测试是否大写"""
self.assertTrue('HELLO'.isupper())
self.assertFalse('Hello'.isupper())
def test_split(self):
"""测试字符串分割"""
s = 'hello world'
self.assertEqual(s.split(), ['hello', 'world'])
# 检查分割后列表长度
with self.assertRaises(TypeError):
s.split(2)
if __name__ == '__main__':
unittest.main()
2.2 常用断言方法¶
Python
import unittest
class TestAssertions(unittest.TestCase):
"""断言方法示例"""
def test_equality(self):
"""相等性断言"""
# 相等
self.assertEqual(1 + 1, 2)
self.assertEqual("hello", "hello")
# 不相等
self.assertNotEqual(1 + 1, 3)
def test_boolean(self):
"""布尔断言"""
self.assertTrue(1 == 1)
self.assertFalse(1 == 2)
def test_comparison(self):
"""比较断言"""
self.assertGreater(5, 3) # 5 > 3
self.assertLess(3, 5) # 3 < 5
self.assertGreaterEqual(5, 5) # 5 >= 5
self.assertLessEqual(3, 5) # 3 <= 5
def test_membership(self):
"""成员断言"""
self.assertIn(1, [1, 2, 3])
self.assertNotIn(4, [1, 2, 3])
def test_type(self):
"""类型断言"""
self.assertIsInstance("hello", str)
self.assertIsInstance([1, 2], list)
def test_none(self):
"""None断言"""
self.assertIsNone(None)
self.assertIsNotNone("value")
def test_exception(self):
"""异常断言"""
with self.assertRaises(ValueError):
int("not a number")
with self.assertRaises(TypeError):
"string" + 123
def test_almost_equal(self):
"""浮点数近似相等"""
self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)
2.3 测试固件 (Fixtures)¶
Python
import unittest
import tempfile
import os
class TestFileOperations(unittest.TestCase):
"""文件操作测试 - 展示测试固件"""
@classmethod # @classmethod类方法,第一个参数为类本身
def setUpClass(cls):
"""类级别设置 - 整个测试类执行一次"""
print("开始测试文件操作...")
cls.test_dir = tempfile.mkdtemp()
@classmethod
def tearDownClass(cls):
"""类级别清理"""
print("测试完成!")
import shutil
shutil.rmtree(cls.test_dir)
def setUp(self):
"""每个测试方法前执行"""
self.test_file = os.path.join(self.test_dir, 'test.txt')
with open(self.test_file, 'w') as f:
f.write('test content')
def tearDown(self):
"""每个测试方法后执行"""
if os.path.exists(self.test_file):
os.remove(self.test_file)
def test_file_exists(self):
"""测试文件存在"""
self.assertTrue(os.path.exists(self.test_file))
def test_file_content(self):
"""测试文件内容"""
with open(self.test_file, 'r') as f:
content = f.read()
self.assertEqual(content, 'test content')
2.4 测试套件¶
Python
import unittest
# 创建测试套件
def create_test_suite():
"""创建组合测试套件"""
suite = unittest.TestSuite()
# 添加单个测试用例
suite.addTest(TestStringMethods('test_upper'))
# 添加整个测试类
suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestStringMethods))
# 从模块加载
# suite.addTests(unittest.TestLoader().loadTestsFromModule(my_module))
return suite
# 运行套件
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(create_test_suite())
3. pytest框架¶
3.1 基本用法¶
Python
# pytest 不需要继承任何类,更简洁
def test_simple():
"""最简单的测试"""
assert 1 + 1 == 2
def test_string_operations():
"""字符串操作测试"""
assert "hello".upper() == "HELLO"
assert "HELLO".isupper() is True
assert "hello world".split() == ["hello", "world"]
def test_list_operations():
"""列表操作测试"""
my_list = [1, 2, 3]
my_list.append(4)
assert len(my_list) == 4
assert 4 in my_list
3.2 pytest断言¶
Python
import pytest
def test_equality():
"""相等性断言 - pytest直接使用assert"""
assert 1 + 1 == 2
assert "hello" == "hello"
def test_approximation():
"""近似值断言"""
assert 0.1 + 0.2 == pytest.approx(0.3)
def test_exception():
"""异常断言"""
with pytest.raises(ValueError):
int("not a number")
# 检查异常消息
with pytest.raises(ValueError, match="invalid literal"):
int("abc")
def test_warning():
"""警告断言"""
import warnings
with pytest.warns(UserWarning):
warnings.warn("test warning", UserWarning)
def test_custom_message():
"""自定义断言失败消息"""
value = 5
assert value > 10, f"值 {value} 应该大于10"
3.3 pytest fixtures¶
Python
import pytest
# 简单fixture
@pytest.fixture
def sample_data():
"""返回测试数据"""
return {"name": "test", "value": 42}
def test_with_fixture(sample_data):
"""使用fixture的测试"""
assert sample_data["name"] == "test"
assert sample_data["value"] == 42
# 带清理的fixture
@pytest.fixture
def temp_file():
"""创建临时文件并在测试后清理"""
import tempfile
import os
# setup
fd, path = tempfile.mkstemp()
os.write(fd, b"test content")
os.close(fd)
yield path # 返回给测试使用
# teardown (yield之后的代码)
os.unlink(path)
def test_temp_file(temp_file):
"""使用临时文件测试"""
with open(temp_file, 'r') as f:
content = f.read()
assert content == "test content"
# fixture作用域
@pytest.fixture(scope="module")
def database_connection():
"""模块级别fixture - 整个模块只执行一次"""
print("\n连接数据库...")
conn = {"connected": True}
yield conn
print("\n关闭数据库连接...")
conn["connected"] = False
def test_query_1(database_connection):
assert database_connection["connected"] is True
def test_query_2(database_connection):
assert database_connection["connected"] is True
3.4 参数化测试¶
Python
import pytest
# 单参数
@pytest.mark.parametrize("input,expected", [
("hello", "HELLO"),
("WORLD", "WORLD"),
("Python", "PYTHON"),
])
def test_upper_param(input, expected):
"""参数化测试 - 大写转换"""
assert input.upper() == expected
# 多参数
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(5, 5, 10),
(-1, 1, 0),
(100, -50, 50),
])
def test_add_param(a, b, expected):
"""参数化测试 - 加法"""
assert a + b == expected
# 组合参数化
@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", [10, 20])
def test_combination(x, y):
"""组合参数化 - 会产生4个测试"""
assert x + y in [11, 12, 21, 22]
# 使用ids命名
@pytest.mark.parametrize("input,expected", [
("hello", "HELLO"),
("world", "WORLD"),
], ids=["lowercase_hello", "lowercase_world"])
def test_with_ids(input, expected):
assert input.upper() == expected
3.5 标记和跳过¶
Python
import sys
import pytest
# 跳过测试
@pytest.mark.skip(reason="功能尚未实现")
def test_future_feature():
assert False
# 条件跳过
@pytest.mark.skipif(sys.version_info < (3, 8), reason="需要Python 3.8+")
def test_python38_feature():
# 使用Python 3.8+的特性
pass
# 标记预期失败
@pytest.mark.xfail(reason="已知bug,待修复")
def test_known_bug():
assert 1 == 2 # 这个测试预期会失败
# 自定义标记
@pytest.mark.slow
def test_slow_operation():
"""慢速测试"""
import time
time.sleep(2)
assert True
# 运行时选择
# pytest -m slow # 只运行slow标记的测试
# pytest -m "not slow" # 运行非slow的测试
4. Mock和测试替身¶
4.1 Mock基础¶
Python
from unittest.mock import Mock, MagicMock, patch
def test_mock_basic():
"""Mock基础用法"""
# 创建Mock对象
mock = Mock()
# 设置返回值
mock.method.return_value = 42
assert mock.method() == 42
# 验证调用
mock.method.assert_called_once()
def test_mock_attributes():
"""Mock属性"""
mock = Mock()
mock.name = "test"
mock.value = 100
assert mock.name == "test"
assert mock.value == 100
def test_mock_side_effect():
"""Mock副作用"""
mock = Mock()
# 使用函数作为副作用
mock.side_effect = lambda x: x * 2 # lambda匿名函数:简洁的单行函数
assert mock(5) == 10
assert mock(3) == 6
# 使用异常作为副作用
mock.side_effect = ValueError("error")
try: # try/except捕获异常
mock()
except ValueError as e:
assert str(e) == "error"
# 使用列表作为副作用(顺序返回)
mock.side_effect = [1, 2, 3]
assert mock() == 1
assert mock() == 2
assert mock() == 3
4.2 patch装饰器¶
Python
from unittest.mock import patch, Mock
import requests
# 被测试的函数
def get_user_info(user_id):
"""获取用户信息"""
response = requests.get(f"https://api.example.com/users/{user_id}")
return response.json()
# 使用patch模拟requests
@patch('requests.get')
def test_get_user_info(mock_get):
"""测试获取用户信息"""
# 设置模拟响应
mock_response = Mock()
mock_response.json.return_value = {"id": 1, "name": "Test User"}
mock_get.return_value = mock_response
# 调用被测函数
result = get_user_info(1)
# 验证结果
assert result == {"id": 1, "name": "Test User"}
mock_get.assert_called_once_with("https://api.example.com/users/1")
# 上下文管理器方式
def test_with_context_manager():
"""使用上下文管理器"""
with patch('requests.get') as mock_get:
mock_get.return_value.json.return_value = {"id": 2}
result = get_user_info(2)
assert result["id"] == 2
4.3 pytest-mock插件¶
Python
# pytest-mock 提供了mocker fixture
def test_with_mocker(mocker):
"""使用mocker fixture"""
# Mock整个类
mock_requests = mocker.patch('requests.get')
mock_requests.return_value.json.return_value = {"status": "ok"}
# Mock对象方法
obj = mocker.MagicMock()
obj.method.return_value = 42
# Spy - 部分Mock
# mocker.spy 实际调用会被执行,但可以验证调用情况
5. 测试覆盖率¶
5.1 使用coverage.py¶
Bash
# 安装
pip install coverage
# 运行测试并收集覆盖率
coverage run -m pytest
# 生成报告
coverage report
# 生成HTML报告
coverage html
# 查看报告
open htmlcov/index.html
5.2 pytest-cov插件¶
Bash
# 安装
pip install pytest-cov
# 运行测试并显示覆盖率
pytest --cov=myproject tests/
# 生成HTML报告
pytest --cov=myproject --cov-report=html tests/
5.3 配置覆盖率¶
INI
# .coveragerc 或 setup.cfg
[run]
source = myproject
omit =
*/tests/*
*/__pycache__/*
*/migrations/*
[report]
exclude_lines =
pragma: no cover
def __repr__
raise NotImplementedError
if __name__ == .__main__.:
fail_under = 80
6. 高级测试技巧¶
6.1 测试私有方法¶
Python
class Calculator:
"""计算器类"""
def _validate_input(self, value):
"""私有方法 - 验证输入"""
if not isinstance(value, (int, float)): # isinstance检查对象类型
raise TypeError("必须是数字")
return True
def add(self, a, b):
"""公共方法"""
self._validate_input(a)
self._validate_input(b)
return a + b
# 测试方式1: 通过公共方法间接测试
def test_add_validates_input():
calc = Calculator()
with pytest.raises(TypeError):
calc.add("not a number", 1)
# 测试方式2: 直接访问私有方法(不推荐,但有时必要)
def test_private_method_directly():
calc = Calculator()
calc._validate_input(5) # 可以访问
with pytest.raises(TypeError):
calc._validate_input("invalid")
6.2 测试异常和边界¶
Python
import pytest
def divide(a, b):
"""除法函数"""
if b == 0:
raise ValueError("除数不能为零")
return a / b
class TestDivide:
"""除法测试类"""
def test_normal(self):
"""正常情况"""
assert divide(10, 2) == 5
assert divide(9, 3) == 3
def test_negative(self):
"""负数情况"""
assert divide(-10, 2) == -5
assert divide(10, -2) == -5
def test_float(self):
"""浮点数情况"""
assert divide(5, 2) == 2.5
def test_zero_divisor(self):
"""除数为零"""
with pytest.raises(ValueError, match="除数不能为零"):
divide(10, 0)
@pytest.mark.parametrize("a,b,expected", [
(0, 1, 0), # 被除数为零
(1, 1, 1), # 相同数
(1e10, 1, 1e10), # 大数
(1e-10, 1, 1e-10), # 小数
])
def test_edge_cases(self, a, b, expected):
"""边界情况"""
assert divide(a, b) == pytest.approx(expected)
6.3 测试文件系统¶
Python
import pytest
import tempfile
import os
@pytest.fixture
def temp_dir():
"""临时目录fixture"""
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir
def test_file_operations(temp_dir):
"""测试文件操作"""
# 创建文件
file_path = os.path.join(temp_dir, "test.txt")
# 写入
with open(file_path, 'w') as f: # with自动管理资源,确保文件正确关闭
f.write("test content")
# 读取验证
with open(file_path, 'r') as f:
content = f.read()
assert content == "test content"
# 文件存在
assert os.path.exists(file_path)
6.4 测试数据库操作¶
Python
import pytest
import sqlite3
@pytest.fixture
def in_memory_db():
"""内存数据库fixture"""
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
# 创建测试表
cursor.execute('''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
)
''')
conn.commit()
yield conn # yield生成器:惰性产出值,节省内存
conn.close()
def test_insert_user(in_memory_db):
"""测试插入用户"""
cursor = in_memory_db.cursor()
cursor.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
("Test User", "test@example.com")
)
in_memory_db.commit()
cursor.execute("SELECT * FROM users")
users = cursor.fetchall()
assert len(users) == 1
assert users[0][1] == "Test User"
def test_unique_email(in_memory_db):
"""测试邮箱唯一性"""
cursor = in_memory_db.cursor()
cursor.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
("User1", "same@example.com")
)
in_memory_db.commit()
# 插入相同邮箱应该失败
with pytest.raises(sqlite3.IntegrityError):
cursor.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
("User2", "same@example.com")
)
7. 面试题精选¶
Q1: unittest和pytest有什么区别?¶
参考答案:
| 特性 | unittest | pytest |
|---|---|---|
| 断言 | self.assertEqual() | assert |
| 发现 | 需要继承TestCase | 自动发现test_*.py |
| Fixtures | setUp/tearDown | @pytest.fixture |
| 参数化 | 需要subTest | @pytest.mark.parametrize |
| 插件生态 | 较少 | 丰富(pytest-cov等) |
Q2: 什么是测试替身?有哪些类型?¶
参考答案: 测试替身是用于替代真实依赖的对象:
- Dummy:只用于填充参数,不实际使用
- Stub:提供预定义的响应
- Spy:记录调用信息用于验证
- Mock:验证交互行为
- Fake:有实际工作实现(如内存数据库)
Q3: 如何测试异步代码?¶
参考答案:
Python
import pytest
@pytest.mark.asyncio
async def test_async_function(): # async def定义异步函数;用await调用
result = await async_operation() # await等待异步操作完成
assert result == expected_value
# 需要安装 pytest-asyncio
8. 最佳实践¶
8.1 测试命名规范¶
Python
# 好的命名
def test_add_should_return_sum_of_two_numbers():
"""清晰描述测试意图"""
pass
def test_divide_should_raise_error_when_divisor_is_zero():
"""包含预期行为和条件"""
pass
# 不好的命名
def test_add(): # 太笼统
pass
def test1(): # 无意义
pass
8.2 AAA模式¶
Python
def test_with_aaa_pattern():
"""Arrange-Act-Assert 模式"""
# Arrange (准备)
calculator = Calculator()
a, b = 5, 3
# Act (执行)
result = calculator.add(a, b)
# Assert (断言)
assert result == 8 # assert断言:条件为False时抛出AssertionError
8.3 测试清单¶
- 每个公共方法都有对应测试?
- 测试覆盖正常和异常路径?
- 测试之间相互独立?
- 使用有意义的测试名称?
- Mock了外部依赖?
- 测试覆盖率达标(≥80%)?
9. 学习检查清单¶
完成本章学习后,请确认你能够:
- 使用unittest编写基本测试
- 使用pytest编写简洁的测试
- 编写和使用fixtures
- 使用参数化测试减少重复代码
- 使用Mock模拟外部依赖
- 测量和提高测试覆盖率
- 遵循测试最佳实践
10. 参考资料¶
官方文档¶
相关教程¶
最后更新日期:2026-02-17 适用版本:测试与质量保证教程 v2026