
08 - Transactions and Concurrency Control

Transactions and Concurrency Control

📋 Chapter Overview

Transactions are one of the core features of a database, guaranteeing data consistency and integrity. This chapter covers the ACID properties, isolation levels, and the mechanisms that protect data consistency under high concurrency, with particular attention to AI/ML workloads such as concurrent training and data pipelines.

Learning objectives: - Understand the ACID properties and why they matter - Master transaction commit and rollback - Understand the different isolation levels and their effects - Learn to handle concurrency problems and deadlocks - Master transaction design patterns for AI workloads

Estimated study time: 5-7 hours

Prerequisite: Chapter 07: Database Optimization and Tuning


1. Transaction Basics

1.1 What Is a Transaction

A transaction is a logical unit of work: a group of database operations that either all execute successfully or all take no effect.

Text Only
┌─────────────────────────────────────────────────────────────┐
│                   Transaction Lifecycle                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   BEGIN ──→ operation 1 ──→ operation 2 ──→ operation 3     │
│                                  │                          │
│                             ┌────┴────┐                     │
│                             ↓         ↓                     │
│                          COMMIT    ROLLBACK                 │
│                       (all saved) (all undone)              │
│                                                             │
└─────────────────────────────────────────────────────────────┘
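
The lifecycle above can be run end to end with Python's stdlib sqlite3 driver, used here as an in-memory stand-in for the MySQL/PostgreSQL servers in the rest of the chapter (the `accounts` table is a made-up example):

```python
import sqlite3

# In-memory database as a stand-in for a real server
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER)")
conn.execute("INSERT INTO accounts VALUES (1, 500), (2, 500)")
conn.commit()

# COMMIT path: both updates become permanent together
conn.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1")
conn.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2")
conn.commit()
print(conn.execute("SELECT balance FROM accounts ORDER BY id").fetchall())
# → [(400,), (600,)]

# ROLLBACK path: the pending update is undone, never half-applied
conn.execute("UPDATE accounts SET balance = balance - 999 WHERE id = 1")
conn.rollback()
print(conn.execute("SELECT balance FROM accounts ORDER BY id").fetchall())
# → [(400,), (600,)]
```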

1.2 The ACID Properties

Python
"""
The ACID properties

A - Atomicity
C - Consistency
I - Isolation
D - Durability
"""

class ACIDExplanation:
    """Explains each ACID property"""

    def atomicity(self):
        """Atomicity: a transaction is an indivisible unit of work"""
        example = """
        Bank transfer example:
        1. Deduct 100 from account A
        2. Add 100 to account B

        Atomicity guarantees that both operations succeed or both fail;
        A can never be debited while B never receives the funds.
        """
        return example

    def consistency(self):
        """Consistency: the database moves from one valid state to another"""
        example = """
        Consistency constraints:
        - The combined balance of both accounts is unchanged by a transfer
        - Stock levels cannot go negative
        - Foreign-key constraints must hold
        """
        return example

    def isolation(self):
        """Isolation: concurrent transactions do not interfere with each other"""
        example = """
        Isolation levels:
        - READ UNCOMMITTED: weakest; dirty reads are possible
        - READ COMMITTED: prevents dirty reads
        - REPEATABLE READ: prevents non-repeatable reads (MySQL default)
        - SERIALIZABLE: strongest; transactions behave as if run serially
        """
        return example

    def durability(self):
        """Durability: once committed, data survives failures"""
        example = """
        Durability guarantees:
        - Committed data is not lost even if the system crashes
        - Implemented via WAL (Write-Ahead Logging)
        - The log record is flushed to disk before the commit returns
        """
        return example
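
The WAL idea behind durability can be sketched in a few lines of plain Python. This is a toy model only — real engines write binary log records and fsync them — but it shows why logging the change *before* applying it makes recovery possible:

```python
# Toy write-ahead log: append the intended change to the log before
# applying it, so a crash between the two steps can be replayed.
# Illustrative only -- not any real engine's log format.

log = []           # stands in for the on-disk WAL (fsync'd at commit)
data = {"A": 100}  # stands in for the data pages

def wal_update(key, value):
    log.append(("SET", key, value))  # 1. write the log record first
    data[key] = value                # 2. then apply to the data pages

def replay(log, initial):
    """Crash recovery: reapply every logged record in order."""
    state = dict(initial)
    for op, key, value in log:
        if op == "SET":
            state[key] = value
    return state

wal_update("A", 150)
wal_update("B", 42)

# Simulated crash: the data pages are lost, the log survives
recovered = replay(log, {"A": 100})
print(recovered)  # → {'A': 150, 'B': 42}
```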

1.3 Basic Transaction Operations

SQL
-- MySQL transaction example
-- Start a transaction
START TRANSACTION;
-- or
BEGIN;

-- Perform operations
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
UPDATE accounts SET balance = balance + 100 WHERE id = 2;

-- Commit (make the changes permanent)
COMMIT;

-- or roll back (undo all changes)
ROLLBACK;

-- PostgreSQL transaction example
BEGIN;

INSERT INTO experiments (name, status) VALUES ('exp_001', 'running');
UPDATE model_configs SET last_used = NOW() WHERE model_id = 'm001';

-- Savepoint (for partial rollback)
SAVEPOINT before_validation;

-- If validation fails, roll back to the savepoint
ROLLBACK TO SAVEPOINT before_validation;

-- Final commit
COMMIT;
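
The savepoint pattern above is runnable as-is against SQLite, again as a server-free stand-in (`isolation_level=None` puts the stdlib driver in manual transaction mode so the explicit BEGIN/COMMIT work):

```python
import sqlite3

# Manual transaction control (isolation_level=None = autocommit mode)
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE experiments (name TEXT, status TEXT)")

conn.execute("BEGIN")
conn.execute("INSERT INTO experiments VALUES ('exp_001', 'running')")

# Work after the savepoint can be undone without losing the insert above
conn.execute("SAVEPOINT before_validation")
conn.execute("INSERT INTO experiments VALUES ('exp_001_bad', 'invalid')")

# "Validation failed": undo back to the savepoint only
conn.execute("ROLLBACK TO SAVEPOINT before_validation")
conn.execute("COMMIT")

print(conn.execute("SELECT name FROM experiments").fetchall())
# → [('exp_001',)]
```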
Python
# Transactions in Python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')
Session = sessionmaker(bind=engine)

# Option 1: manage the transaction manually
session = Session()
try:
    # Perform database operations
    user = User(name='Alice', balance=1000)
    session.add(user)
    session.flush()  # flush so user.id is populated before it is referenced

    order = Order(user_id=user.id, amount=100)
    session.add(order)

    # Commit the transaction
    session.commit()
    print("Transaction committed")

except Exception as e:
    # On error, roll back the transaction
    session.rollback()
    print(f"Transaction rolled back: {e}")

finally:
    session.close()

# Option 2: context manager (recommended)
with Session() as session:
    try:
        user = User(name='Bob', balance=2000)
        session.add(user)
        session.commit()
    except Exception:
        session.rollback()
        raise
    # the session is closed automatically

2. Concurrency Problems

2.1 Common Concurrency Anomalies

Text Only
┌─────────────────────────────────────────────────────────────┐
│                   Concurrency Anomalies                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Dirty Read                                              │
│     Transaction A reads data that transaction B has not     │
│     yet committed                                           │
│                                                             │
│     Timeline:                                               │
│     T1: B modifies the data (uncommitted)                   │
│     T2: A reads the modified data                           │
│     T3: B rolls back                                        │
│     T4: A has used data that "never existed"                │
│                                                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  2. Non-repeatable Read                                     │
│     Within one transaction, reading the same row twice      │
│     returns different values                                │
│                                                             │
│     Timeline:                                               │
│     T1: A reads the value = 100                             │
│     T2: B updates it to 200 and commits                     │
│     T3: A reads again = 200 (inconsistent!)                 │
│                                                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  3. Phantom Read                                            │
│     Within one transaction, the same query returns a        │
│     different number of rows                                │
│                                                             │
│     Timeline:                                               │
│     T1: A queries WHERE age > 20, gets 10 rows              │
│     T2: B inserts a row with age = 25 and commits           │
│     T3: A queries again, gets 11 rows (a "phantom" row)     │
│                                                             │
└─────────────────────────────────────────────────────────────┘
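
A fourth classic anomaly worth knowing is the lost update: two read-modify-write cycles interleave, and the later write silently overwrites the earlier one. The interleaving can be simulated in plain Python, no database needed:

```python
# Lost update, simulated: two "transactions" each read a balance,
# add to it, and write back. Interleaved without locking, one of
# the updates is lost.

balance = 100

# Transactions A and B both read the same starting value
read_a = balance
read_b = balance

# A writes first...
balance = read_a + 50   # balance = 150

# ...then B writes, based on its stale read: A's +50 is lost
balance = read_b + 30   # balance = 130, not 180

print(balance)  # → 130
```

Row locks (SELECT ... FOR UPDATE) or optimistic version checks, both covered in section 3, are the standard defenses against this interleaving.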

2.2 Isolation Levels in Detail

SQL
-- Check the current isolation level
-- MySQL
SELECT @@transaction_isolation;

-- PostgreSQL
SHOW transaction_isolation;

-- Set the isolation level
-- MySQL
SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED;

-- PostgreSQL
SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;

Isolation level    | Dirty read | Non-repeatable read | Phantom read                                  | Performance
READ UNCOMMITTED   | possible   | possible            | possible                                      | highest
READ COMMITTED     | prevented  | possible            | possible                                      | high
REPEATABLE READ    | prevented  | prevented           | prevented (InnoDB) / possible (SQL standard)  | medium
SERIALIZABLE       | prevented  | prevented           | prevented                                     | lowest

⚠️ Important: the SQL standard vs. the InnoDB implementation

SQL standard: the REPEATABLE READ isolation level permits phantom reads.

MySQL InnoDB: next-key locking (gap locks combined with record locks) prevents phantom reads at REPEATABLE READ. This is an InnoDB enhancement over the standard.

PostgreSQL: its REPEATABLE READ level is implemented as snapshot isolation (MVCC), which also prevents phantom reads; SSI (Serializable Snapshot Isolation) is used only at the SERIALIZABLE level.

In the table above, "prevented (InnoDB)" reflects what most developers actually observe when using MySQL.

Python
# Setting the isolation level from Python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# MySQL defaults to REPEATABLE READ
engine_mysql = create_engine(
    'mysql+pymysql://user:pass@localhost/dbname',
    isolation_level='REPEATABLE READ'  # also: READ UNCOMMITTED, READ COMMITTED, SERIALIZABLE
)

# PostgreSQL defaults to READ COMMITTED
engine_pg = create_engine(
    'postgresql+psycopg2://user:pass@localhost/dbname',
    isolation_level='READ COMMITTED'  # also: SERIALIZABLE, REPEATABLE READ
)

# Per-connection setting
with engine_mysql.connect() as conn:
    conn = conn.execution_options(isolation_level='SERIALIZABLE')
    # perform operations...

2.3 Isolation Levels in Practice

Python
"""
Isolation-level demos
Two concurrent sessions are needed to observe the effects
"""
import threading  # run the two sessions concurrently
import time
from sqlalchemy import create_engine, text

def demonstrate_dirty_read():
    """Demonstrate a dirty read"""
    engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')

    def session_a():
        """Session A: reads the data"""
        with engine.connect() as conn:
            conn = conn.execution_options(isolation_level='READ UNCOMMITTED')

            # Read uncommitted data (dirty read)
            result = conn.execute(text("SELECT balance FROM accounts WHERE id = 1"))
            balance = result.scalar()
            print(f"Session A read: {balance}")  # may be an uncommitted value

    def session_b():
        """Session B: modifies but does not commit"""
        conn = engine.connect()
        trans = conn.begin()
        conn.execute(text("UPDATE accounts SET balance = 999 WHERE id = 1"))
        print("Session B modified the row but has not committed")
        time.sleep(3)  # keep the transaction open
        trans.rollback()  # undo the change -- session A may already have read it
        conn.close()

    # Start the two threads
    t1 = threading.Thread(target=session_b)
    t2 = threading.Thread(target=session_a)

    t1.start()
    time.sleep(0.5)  # make sure B runs first
    t2.start()

    t1.join()
    t2.join()

def demonstrate_non_repeatable_read():
    """Demonstrate a non-repeatable read"""
    engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')

    def session_a():
        """Session A: reads twice inside one transaction"""
        with engine.connect() as conn:
            # Isolation must be set before the transaction begins
            conn = conn.execution_options(isolation_level='READ COMMITTED')
            with conn.begin():
                # First read
                result = conn.execute(text("SELECT balance FROM accounts WHERE id = 1"))
                balance1 = result.scalar()
                print(f"Session A, first read: {balance1}")

                time.sleep(2)  # wait for session B to commit a change

                # Second read
                result = conn.execute(text("SELECT balance FROM accounts WHERE id = 1"))
                balance2 = result.scalar()
                print(f"Session A, second read: {balance2}")

                if balance1 != balance2:
                    print("Non-repeatable read observed!")

    def session_b():
        """Session B: modifies and commits"""
        time.sleep(1)  # wait for A's first read
        with engine.begin() as conn:
            conn.execute(text("UPDATE accounts SET balance = balance + 100 WHERE id = 1"))
            print("Session B updated and committed")

    t1 = threading.Thread(target=session_a)
    t2 = threading.Thread(target=session_b)

    t1.start()
    t2.start()
    t1.join()
    t2.join()

3. Locking

3.1 Types of Locks

Text Only
┌─────────────────────────────────────────────────────────────┐
│                    Lock Classification                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  By granularity:                                            │
│  ├─ Row lock: locks a single row; highest concurrency       │
│  ├─ Page lock: locks a data page                            │
│  └─ Table lock: locks the whole table; lowest concurrency   │
│                                                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  By mode:                                                   │
│  ├─ Shared lock (S / read lock): multiple transactions      │
│  │   may hold it at once                                    │
│  │   SELECT ... LOCK IN SHARE MODE                          │
│  │                                                          │
│  └─ Exclusive lock (X / write lock): held by at most one    │
│      transaction                                            │
│      SELECT ... FOR UPDATE                                  │
│      INSERT/UPDATE/DELETE                                   │
│                                                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  By usage pattern:                                          │
│  ├─ Optimistic locking: version checks at commit time       │
│  └─ Pessimistic locking: lock first, then operate           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3.2 Pessimistic Locking

SQL
-- MySQL pessimistic locking

-- 1. Shared (read) lock
BEGIN;  -- the transaction scopes the lock
SELECT * FROM products WHERE id = 1 LOCK IN SHARE MODE;  -- FOR SHARE in MySQL 8.0+
-- other transactions may read, but not modify, the row
COMMIT;

-- 2. Exclusive (write) lock
BEGIN;
SELECT * FROM products WHERE id = 1 FOR UPDATE;
-- plain SELECTs still read an MVCC snapshot, but locking reads and
-- modifications block until this lock is released
UPDATE products SET stock = stock - 1 WHERE id = 1;
COMMIT;

-- 3. Lock-wait timeout
SET innodb_lock_wait_timeout = 50;  -- give up after 50 seconds
Python
# Pessimistic locking in Python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')
Session = sessionmaker(bind=engine)

def pessimistic_lock_demo():
    """Pessimistic locking: stock deduction"""
    session = Session()

    try:
        # 1. Acquire an exclusive lock and read
        result = session.execute(
            text("SELECT * FROM products WHERE id = 1 FOR UPDATE")
        )
        product = result.fetchone()

        if product.stock > 0:
            # 2. Deduct stock
            session.execute(
                text("UPDATE products SET stock = stock - 1 WHERE id = 1")
            )

            # 3. Create the order
            session.execute(
                text("INSERT INTO orders (product_id, quantity) VALUES (1, 1)")
            )

            session.commit()
            print("Order created")
        else:
            print("Out of stock")
            session.rollback()

    except Exception as e:
        session.rollback()
        print(f"Operation failed: {e}")
    finally:
        session.close()

# Using SQLAlchemy ORM's with_for_update
def orm_pessimistic_lock():
    """Pessimistic locking via the ORM"""
    session = Session()

    try:
        # Locking query
        product = session.query(Product).filter(
            Product.id == 1
        ).with_for_update().first()

        if product and product.stock > 0:
            product.stock -= 1

            order = Order(product_id=1, quantity=1)
            session.add(order)

            session.commit()
            print("Order created")
        else:
            print("Out of stock")

    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

3.3 Optimistic Locking

SQL
-- Table design for optimistic locking
CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    stock INT NOT NULL,
    version INT DEFAULT 0,  -- version number
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);

-- Optimistic update logic
-- 1. Read the row, including its version
SELECT id, stock, version FROM products WHERE id = 1;
-- result: id=1, stock=100, version=5

-- 2. Update only if the version is unchanged
UPDATE products
SET stock = stock - 1, version = version + 1
WHERE id = 1 AND version = 5;

-- if the affected row count is 0, the version changed: retry
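
The version-guarded UPDATE above can be exercised against SQLite (as a server-free stand-in): a stale expected version touches zero rows, which is exactly the conflict signal the retry loop looks for:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE products (id INTEGER PRIMARY KEY, stock INTEGER, version INTEGER)"
)
conn.execute("INSERT INTO products VALUES (1, 100, 5)")
conn.commit()

def deduct(conn, product_id, expected_version):
    """Guarded update: succeeds only if the version is unchanged."""
    cur = conn.execute(
        "UPDATE products SET stock = stock - 1, version = version + 1 "
        "WHERE id = ? AND version = ?",
        (product_id, expected_version),
    )
    conn.commit()
    return cur.rowcount == 1   # 0 rows touched means a version conflict

print(deduct(conn, 1, 5))  # → True  (version was 5, update applied)
print(deduct(conn, 1, 5))  # → False (version is now 6: stale read, retry)
```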
Python
# Optimistic locking in Python
import time

from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import func

Base = declarative_base()

class Product(Base):
    __tablename__ = 'products'

    id = Column(Integer, primary_key=True)
    name = Column(String(100))
    stock = Column(Integer, nullable=False)
    version = Column(Integer, default=0)  # optimistic-lock version number
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())

engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')
Session = sessionmaker(bind=engine)

class OptimisticLock:
    """Optimistic-lock manager"""

    MAX_RETRIES = 3

    def deduct_stock(self, product_id: int, quantity: int = 1):
        """
        Deduct stock with optimistic-lock retries

        Args:
            product_id: product ID
            quantity: amount to deduct

        Returns:
            bool: whether the deduction succeeded
        """
        for attempt in range(self.MAX_RETRIES):
            session = Session()

            try:
                # 1. Read the product and its version
                product = session.query(Product).filter(
                    Product.id == product_id
                ).first()

                if not product or product.stock < quantity:
                    return False

                current_version = product.version

                # 2. Attempt the update, guarded by the version check
                result = session.query(Product).filter(
                    Product.id == product_id,
                    Product.version == current_version
                ).update({
                    'stock': Product.stock - quantity,
                    'version': Product.version + 1
                })

                # 3. Commit only if the guarded update hit a row
                if result == 1:
                    session.commit()
                    print(f"Deducted; version is now {current_version + 1}")
                    return True
                else:
                    session.rollback()
                    print(f"Version conflict, retry {attempt + 1}...")
                    time.sleep(0.1 * (attempt + 1))  # linear backoff

            except Exception as e:
                session.rollback()
                print(f"Operation failed: {e}")
                return False
            finally:
                session.close()

        print("Retry limit exceeded")
        return False

# Usage
lock_manager = OptimisticLock()
success = lock_manager.deduct_stock(product_id=1, quantity=1)

3.4 Choosing a Locking Strategy

Python
"""
Decision logic for choosing a lock strategy
"""

def choose_lock_strategy(scenario):
    """
    Pick a lock strategy for a given scenario

    Args:
        scenario: a dict with the fields
            - read_ratio: fraction of reads (0-1)
            - conflict_probability: probability of conflict (0-1)
            - operation_duration: 'short' / 'medium' / 'long'
            - consistency_requirement: 'high' / 'medium' / 'low'
    """
    read_ratio = scenario.get('read_ratio', 0.5)
    conflict_prob = scenario.get('conflict_probability', 0.1)
    duration = scenario.get('operation_duration', 'short')
    consistency = scenario.get('consistency_requirement', 'medium')

    # Decision logic
    if conflict_prob < 0.01 and read_ratio > 0.9:
        return {
            'strategy': 'optimistic',
            'reason': 'read-heavy with very rare conflicts; optimistic locking performs best',
            'implementation': 'version column'
        }

    elif conflict_prob > 0.3 or consistency == 'high':
        return {
            'strategy': 'pessimistic',
            'reason': 'high conflict probability or strict consistency; pessimistic locking is safer',
            'implementation': 'SELECT FOR UPDATE'
        }

    elif duration == 'long':
        return {
            'strategy': 'optimistic',
            'reason': 'long-running operations; a pessimistic lock would block others too long',
            'implementation': 'version column + retries'
        }

    else:
        return {
            'strategy': 'pessimistic',
            'reason': 'general case; pessimistic locking is simple and reliable',
            'implementation': 'row locks'
        }

# Example scenarios
scenarios = [
    {
        'name': 'product stock deduction',
        'read_ratio': 0.8,
        'conflict_probability': 0.05,
        'operation_duration': 'short',
        'consistency_requirement': 'high'
    },
    {
        'name': 'training-job status updates',
        'read_ratio': 0.3,
        'conflict_probability': 0.001,
        'operation_duration': 'long',
        'consistency_requirement': 'medium'
    },
    {
        'name': 'account balance transfer',
        'read_ratio': 0.2,
        'conflict_probability': 0.4,
        'operation_duration': 'short',
        'consistency_requirement': 'high'
    }
]

for scenario in scenarios:
    result = choose_lock_strategy(scenario)
    print(f"\nScenario: {scenario['name']}")
    print(f"Recommended: {result['strategy']}")
    print(f"Reason: {result['reason']}")
    print(f"Implementation: {result['implementation']}")

4. Deadlock Handling

4.1 How Deadlocks Arise

Text Only
┌─────────────────────────────────────────────────────────────┐
│                     Deadlock Example                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Transaction A                          Transaction B       │
│       │                                      │              │
│       ▼                                      ▼              │
│  locks row 1                            locks row 2         │
│       │                                      │              │
│       ▼                                      ▼              │
│  requests row 2 ◄────── waiting ──────► requests row 1      │
│       │                                      │              │
│       └──────────────  deadlock! ────────────┘              │
│                                                             │
│  Outcome: the database picks one transaction as the         │
│  victim and rolls it back                                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
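
The standard fix — acquire locks in one global order so a circular wait can never form — can be shown with two plain `threading.Lock` objects. Here the threads *want* the locks in opposite orders, the exact interleaving in the diagram, but both sort before acquiring:

```python
import threading

lock_1, lock_2 = threading.Lock(), threading.Lock()
results = []

def transfer(name, locks):
    # Both threads sort the locks the same way, so neither can hold
    # one lock while waiting for the other held in reverse order.
    first, second = sorted(locks, key=id)
    with first:
        with second:
            results.append(name)

# Thread A "wants" (lock_1, lock_2); thread B wants (lock_2, lock_1).
# Without the sort this interleaving can deadlock; with it, it cannot.
ta = threading.Thread(target=transfer, args=("A", [lock_1, lock_2]))
tb = threading.Thread(target=transfer, args=("B", [lock_2, lock_1]))
ta.start(); tb.start()
ta.join(); tb.join()

print(sorted(results))  # → ['A', 'B']
```

The `safe_transfer` function below applies the same idea at the database level by sorting account IDs before issuing the FOR UPDATE reads.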

4.2 Deadlock Detection and Handling

Python
# Deadlock detection and handling
import time
import random
from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError

engine = create_engine('mysql+pymysql://user:pass@localhost/dbname')

def safe_transfer(from_id: int, to_id: int, amount: float, max_retries: int = 3):
    """
    Deadlock-safe transfer

    Strategy:
    1. Always acquire locks in a fixed order
    2. Set a lock-wait timeout
    3. Retry automatically on deadlock
    """

    # Lock rows in ID order to avoid circular waits
    first_id, second_id = sorted([from_id, to_id])

    for attempt in range(max_retries):
        try:
            with engine.begin() as conn:
                # Lock-wait timeout for this session
                conn.execute(text("SET innodb_lock_wait_timeout = 5"))

                # Acquire locks in the fixed order
                conn.execute(
                    text("SELECT * FROM accounts WHERE id = :id FOR UPDATE"),
                    {"id": first_id}
                )

                conn.execute(
                    text("SELECT * FROM accounts WHERE id = :id FOR UPDATE"),
                    {"id": second_id}
                )

                # Check the balance
                result = conn.execute(
                    text("SELECT balance FROM accounts WHERE id = :id"),
                    {"id": from_id}
                )
                balance = result.scalar()

                if balance < amount:
                    raise ValueError(f"Insufficient balance: {balance} < {amount}")

                # Perform the transfer
                conn.execute(
                    text("UPDATE accounts SET balance = balance - :amount WHERE id = :id"),
                    {"amount": amount, "id": from_id}
                )

                conn.execute(
                    text("UPDATE accounts SET balance = balance + :amount WHERE id = :id"),
                    {"amount": amount, "id": to_id}
                )

                print(f"Transfer succeeded: {from_id} -> {to_id}, amount: {amount}")
                return True

        except OperationalError as e:
            if "Deadlock found" in str(e) or "Lock wait timeout" in str(e):
                print(f"Deadlock/lock timeout detected, retry {attempt + 1}...")
                time.sleep(random.uniform(0.1, 0.5) * (attempt + 1))
            else:
                raise
        except Exception as e:
            print(f"Transfer failed: {e}")
            return False

    print("Retry limit exceeded; transfer failed")
    return False

# Deadlock monitoring queries (MySQL)
DEADLOCK_MONITOR_SQL = """
-- Show the most recent deadlock
SHOW ENGINE INNODB STATUS;

-- Current lock holders (MySQL 8.0+, via performance_schema)
SELECT
    dl.OBJECT_NAME AS locked_table,
    dl.LOCK_TYPE,
    dl.LOCK_MODE,
    dl.LOCK_STATUS,
    dl.OWNER_THREAD_ID AS blocking_thread_id,
    t.PROCESSLIST_ID AS blocking_process_id,
    t.PROCESSLIST_INFO AS blocking_query
FROM performance_schema.data_locks dl
INNER JOIN performance_schema.threads t ON dl.OWNER_THREAD_ID = t.THREAD_ID
WHERE dl.LOCK_STATUS = 'GRANTED'
ORDER BY dl.OBJECT_NAME;

-- Lock-wait timeout configuration
SHOW VARIABLES LIKE 'innodb_lock_wait_timeout';
"""

4.3 Deadlock Prevention Strategies

Python
"""
Deadlock-prevention best practices
"""

class DeadlockPrevention:
    """Deadlock-prevention strategies"""

    @staticmethod
    def consistent_ordering(resources):
        """
        Strategy 1: resource ordering
        All transactions access resources in the same order
        """
        return sorted(resources)

    @staticmethod
    def timeout_based(resources, timeout=5):
        """
        Strategy 2: timeout and retry
        Set a lock-wait timeout; give up and retry on expiry
        """
        return {
            'resources': resources,
            'timeout': timeout,
            'retry_strategy': 'exponential_backoff'
        }

    @staticmethod
    def one_shot_allocation(all_resources_needed):
        """
        Strategy 3: one-shot allocation
        Acquire every needed lock at the start of the transaction
        """
        return {
            'resources': sorted(all_resources_needed),
            'allocation_strategy': 'all_or_nothing'
        }

    @staticmethod
    def optimistic_with_validation(resources):
        """
        Strategy 4: optimistic locking
        Never block; validate at commit time
        """
        return {
            'strategy': 'optimistic',
            'validation': 'commit_time',
            'retry_on_conflict': True
        }

# Practical example
def batch_update_with_deadlock_prevention(updates):
    """
    Batch update, deadlock-safe

    Args:
        updates: [(id, new_value), ...]
    """
    # Sort by ID so every transaction acquires locks in the same order
    sorted_updates = sorted(updates, key=lambda x: x[0])

    with engine.begin() as conn:
        for id_, value in sorted_updates:
            conn.execute(
                text("UPDATE records SET value = :value WHERE id = :id"),
                {"value": value, "id": id_}
            )

5. Transaction Design for AI Workloads

5.1 Transactions for Model Training

Python
"""
Transaction management for model training
Scenario: training must record several related pieces of state
"""

from sqlalchemy import create_engine, Column, Integer, String, Float, JSON, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import func
import time

Base = declarative_base()

class TrainingJob(Base):
    __tablename__ = 'training_jobs'

    id = Column(Integer, primary_key=True)
    model_name = Column(String(100))
    status = Column(String(20), default='pending')  # pending/running/completed/failed
    hyperparameters = Column(JSON)
    metrics = Column(JSON)
    checkpoint_path = Column(String(255))
    created_at = Column(DateTime, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())

class TrainingEpoch(Base):
    __tablename__ = 'training_epochs'

    id = Column(Integer, primary_key=True)
    job_id = Column(Integer)
    epoch_number = Column(Integer)
    loss = Column(Float)
    accuracy = Column(Float)
    val_loss = Column(Float)
    val_accuracy = Column(Float)
    created_at = Column(DateTime, default=func.now())

engine = create_engine('postgresql+psycopg2://user:pass@localhost/dbname')
Session = sessionmaker(bind=engine)

class TrainingTransactionManager:
    """Transaction manager for training jobs"""

    def __init__(self):
        self.session = Session()

    def start_training(self, model_name: str, hyperparameters: dict):
        """Start training (create the job record)"""
        try:
            job = TrainingJob(
                model_name=model_name,
                status='running',
                hyperparameters=hyperparameters,
                metrics={}
            )
            self.session.add(job)
            self.session.commit()
            return job.id
        except Exception:
            self.session.rollback()
            raise

    def log_epoch(self, job_id: int, epoch: int, metrics: dict):
        """
        Record one training epoch

        Note: this is an independent transaction; a failed write
        must not interrupt training
        """
        try:
            epoch_record = TrainingEpoch(
                job_id=job_id,
                epoch_number=epoch,
                loss=metrics.get('loss'),
                accuracy=metrics.get('accuracy'),
                val_loss=metrics.get('val_loss'),
                val_accuracy=metrics.get('val_accuracy')
            )
            self.session.add(epoch_record)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            # A failed log write does not stop training; just report it
            print(f"Failed to record epoch: {e}")

    def complete_training(self, job_id: int, final_metrics: dict, checkpoint_path: str):
        """Finish training (atomic status update)"""
        try:
            job = self.session.query(TrainingJob).filter(
                TrainingJob.id == job_id
            ).with_for_update().first()

            if job:
                job.status = 'completed'
                job.metrics = final_metrics
                job.checkpoint_path = checkpoint_path
                self.session.commit()
                print(f"Training job {job_id} completed")
            else:
                raise ValueError(f"Training job {job_id} does not exist")

        except Exception as e:
            self.session.rollback()
            # Mark the job as failed
            self._mark_failed(job_id, str(e))
            raise

    def _mark_failed(self, job_id: int, error_msg: str):
        """Mark the job as failed"""
        try:
            job = self.session.query(TrainingJob).filter(
                TrainingJob.id == job_id
            ).first()
            if job:
                job.status = 'failed'
                job.metrics = {'error': error_msg}
                self.session.commit()
        except Exception:
            self.session.rollback()

    def close(self):
        self.session.close()

# Usage
def train_model_with_transaction():
    """Model training with transaction management"""
    manager = TrainingTransactionManager()

    try:
        # 1. Create the training job
        job_id = manager.start_training(
            model_name='resnet50',
            hyperparameters={'lr': 0.001, 'batch_size': 32}
        )
        print(f"Training job created: {job_id}")

        # 2. Simulate the training loop
        for epoch in range(10):
            # training step...
            time.sleep(0.1)

            # Record this epoch's metrics
            metrics = {
                'loss': 1.0 / (epoch + 1),
                'accuracy': 0.5 + epoch * 0.05,
                'val_loss': 1.2 / (epoch + 1),
                'val_accuracy': 0.45 + epoch * 0.04
            }
            manager.log_epoch(job_id, epoch, metrics)

        # 3. Finish training
        manager.complete_training(
            job_id=job_id,
            final_metrics={'final_accuracy': 0.95},
            checkpoint_path=f'/models/job_{job_id}/model.pth'
        )

    except Exception as e:
        print(f"Training failed: {e}")
    finally:
        manager.close()

5.2 Data Pipeline Transactions

Python
"""
Transaction management for an ETL pipeline
Ensures the completeness and consistency of processed data
"""

class DataPipelineTransaction:
    """Transaction management for a data pipeline

    Note: target_table / source_query must come from trusted internal
    configuration, never from user input, since they are interpolated
    into SQL text below (SQL-injection risk otherwise).
    """

    def __init__(self, engine):
        self.engine = engine

    def etl_with_transaction(self, source_query: str, target_table: str, transform_func):
        """
        ETL with transactional guarantees

        Strategy:
        1. Create a staging table
        2. Process the data into the staging table
        3. Validate data quality
        4. Swap the staging table in atomically

        Note: in MySQL, DDL statements (CREATE/RENAME/DROP TABLE) commit
        implicitly, so the real atomicity here comes from the single
        RENAME TABLE swap, not from the surrounding transaction.
        """
        temp_table = f"{target_table}_temp"
        backup_table = f"{target_table}_backup"

        with self.engine.begin() as conn:
            try:
                # 1. Create the staging table
                conn.execute(text(f"""
                    CREATE TABLE {temp_table} LIKE {target_table}
                """))

                # 2. Extract and transform
                result = conn.execute(text(source_query))
                rows = result.fetchall()

                # 3. Bulk-insert into the staging table
                transformed_data = [transform_func(row) for row in rows]

                if transformed_data:
                    # SQLAlchemy text() uses the :named parameter style
                    col_count = len(transformed_data[0])
                    placeholders = ', '.join([f':p{i}' for i in range(col_count)])
                    stmt = text(f"INSERT INTO {temp_table} VALUES ({placeholders})")
                    conn.execute(stmt, [
                        {f'p{i}': v for i, v in enumerate(row)}
                        for row in transformed_data
                    ])

                # 4. Data-quality checks
                validation_result = self._validate_data(conn, temp_table)
                if not validation_result['valid']:
                    raise ValueError(f"Validation failed: {validation_result['errors']}")

                # 5. Atomic swap: a single RENAME TABLE statement swaps
                #    both tables atomically in MySQL
                conn.execute(text(
                    f"RENAME TABLE {target_table} TO {backup_table}, "
                    f"{temp_table} TO {target_table}"
                ))
                conn.execute(text(f"DROP TABLE {backup_table}"))

                print(f"ETL finished: processed {len(rows)} rows")

            except Exception:
                # Clean up the staging table
                conn.execute(text(f"DROP TABLE IF EXISTS {temp_table}"))
                raise

    def _validate_data(self, conn, table_name: str):
        """Data-quality validation"""
        errors = []

        # Null-rate check
        result = conn.execute(text(f"""
            SELECT
                COUNT(*) as total,
                SUM(CASE WHEN feature_1 IS NULL THEN 1 ELSE 0 END) as null_count
            FROM {table_name}
        """))
        row = result.fetchone()
        null_rate = row.null_count / row.total if row.total > 0 else 0

        if null_rate > 0.1:  # more than 10% nulls
            errors.append(f"Null rate too high: {null_rate:.2%}")

        # Range check
        result = conn.execute(text(f"""
            SELECT MIN(feature_1), MAX(feature_1) FROM {table_name}
        """))
        min_val, max_val = result.fetchone()

        if min_val < -1000 or max_val > 1000:
            errors.append(f"Values out of range: [{min_val}, {max_val}]")

        return {
            'valid': len(errors) == 0,
            'errors': errors
        }

    def incremental_load(self, target_table: str, new_data: list, watermark_column: str):
        """
        增量加载(使用水印机制)

        策略:
        1. 获取上次加载的水印
        2. 只加载新数据
        3. 更新水印
        """
        with self.engine.begin() as conn:
            # 获取当前水印
            result = conn.execute(text("""
                SELECT MAX(watermark_value)
                FROM etl_watermarks
                WHERE table_name = :table
            """), {"table": target_table})
            last_watermark = result.scalar() or '1970-01-01'

            # 过滤新数据
            new_records = [
                record for record in new_data
                if record[watermark_column] > last_watermark
            ]

            if new_records:
                # 插入新数据(SQLAlchemy text() 使用 :named 参数风格)
                col_count = len(new_records[0])
                placeholders = ', '.join([f':p{i}' for i in range(col_count)])
                stmt = text(f"INSERT INTO {target_table} VALUES ({placeholders})")
                conn.execute(stmt, [
                    {f'p{i}': v for i, v in enumerate(record)}
                    for record in new_records
                ])

                # 更新水印
                new_watermark = max(r[watermark_column] for r in new_records)
                conn.execute(text("""
                    INSERT INTO etl_watermarks (table_name, watermark_value, updated_at)
                    VALUES (:table, :watermark, NOW())
                    ON DUPLICATE KEY UPDATE  -- MySQL 语法;PostgreSQL 可改用 ON CONFLICT ... DO UPDATE
                        watermark_value = VALUES(watermark_value),
                        updated_at = NOW()
                """), {"table": target_table, "watermark": new_watermark})

                print(f"增量加载完成:{len(new_records)} 条新记录")
            else:
                print("无新数据,跳过加载")
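水印过滤与更新的核心逻辑不依赖数据库,可以抽成纯函数单独验证。以下为一个示意实现(假设水印是可直接比较的字符串日期,与上面代码的比较方式一致):

```python
def filter_new_records(new_data, watermark_column, last_watermark):
    """按水印过滤增量数据,返回 (新记录列表, 新水印)"""
    new_records = [r for r in new_data if r[watermark_column] > last_watermark]
    if not new_records:
        return [], last_watermark  # 无新数据时水印不变
    new_watermark = max(r[watermark_column] for r in new_records)
    return new_records, new_watermark

# 示例:只有晚于上次水印的记录会被加载
data = [
    {"id": 1, "updated_at": "2024-01-01"},
    {"id": 2, "updated_at": "2024-03-15"},
    {"id": 3, "updated_at": "2024-02-10"},
]
records, wm = filter_new_records(data, "updated_at", "2024-02-01")
print(len(records), wm)  # 2 2024-03-15
```

把过滤逻辑独立出来后,可以在不连数据库的情况下对边界条件(无新数据、水印相等)做单元测试。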

5.3 分布式训练协调

Python
"""
分布式训练的数据库协调
使用数据库作为分布式锁和状态同步中心
"""

import json
import uuid
from datetime import datetime, timedelta

from sqlalchemy import text

class DistributedTrainingCoordinator:
    """分布式训练协调器"""

    def __init__(self, engine):
        self.engine = engine
        self.worker_id = str(uuid.uuid4())

    def acquire_task(self, task_type: str, timeout_seconds: int = 300):
        """
        获取分布式任务(使用数据库实现分布式锁)

        Returns:
            task_id 或 None
        """
        with self.engine.begin() as conn:
            # 使用 SELECT ... FOR UPDATE SKIP LOCKED 实现分布式锁
            # (SKIP LOCKED 需 PostgreSQL 9.5+ / MySQL 8.0+;
            #  INTERVAL '1 second' * :timeout 为 PostgreSQL 语法)
            result = conn.execute(text("""
                SELECT task_id, task_data
                FROM distributed_tasks
                WHERE task_type = :type
                  AND status = 'pending'
                  AND (locked_by IS NULL OR locked_at < NOW() - INTERVAL '1 second' * :timeout)
                ORDER BY priority DESC, created_at ASC
                LIMIT 1
                FOR UPDATE SKIP LOCKED
            """), {
                "type": task_type,
                "timeout": timeout_seconds
            })

            task = result.fetchone()

            if task:
                # 锁定任务
                conn.execute(text("""
                    UPDATE distributed_tasks
                    SET locked_by = :worker,
                        locked_at = NOW(),
                        status = 'running'
                    WHERE task_id = :task_id
                """), {"worker": self.worker_id, "task_id": task.task_id})

                return {
                    'task_id': task.task_id,
                    'task_data': task.task_data
                }

            return None

    def complete_task(self, task_id: str, result_data: dict):
        """完成任务"""
        with self.engine.begin() as conn:
            conn.execute(text("""
                UPDATE distributed_tasks
                SET status = 'completed',
                    result_data = :result,
                    completed_at = NOW(),
                    locked_by = NULL
                WHERE task_id = :task_id
                  AND locked_by = :worker
            """), {
                "result": json.dumps(result_data),  # json.dumps将Python对象转为JSON字符串
                "task_id": task_id,
                "worker": self.worker_id
            })

    def report_progress(self, task_id: str, progress: float, metrics: dict):
        """报告进度(非事务性,允许失败)"""
        try:  # try/except捕获异常
            with self.engine.connect() as conn:
                conn.execute(text("""
                    UPDATE distributed_tasks
                    SET progress = :progress,
                        metrics = :metrics,
                        last_heartbeat = NOW()
                    WHERE task_id = :task_id
                """), {
                    "progress": progress,
                    "metrics": json.dumps(metrics),
                    "task_id": task_id
                })
                conn.commit()
        except Exception as e:
            # 进度报告失败不影响训练
            print(f"进度报告失败: {e}")

    def aggregate_results(self, experiment_id: str):
        """
        聚合分布式训练结果

        使用事务确保聚合的准确性
        """
        with self.engine.begin() as conn:
            # 获取所有完成的任务
            result = conn.execute(text("""
                SELECT result_data
                FROM distributed_tasks
                WHERE experiment_id = :exp_id
                  AND status = 'completed'
                FOR UPDATE
            """), {"exp_id": experiment_id})

            results = [json.loads(row.result_data) for row in result]  # json.loads将JSON字符串转为Python对象

            if not results:
                return None

            # 聚合指标
            aggregated = {
                'experiment_id': experiment_id,
                'total_tasks': len(results),
                'avg_accuracy': sum(r['accuracy'] for r in results) / len(results),
                'avg_loss': sum(r['loss'] for r in results) / len(results),
                'best_accuracy': max(r['accuracy'] for r in results),
                'results': results
            }

            # 保存聚合结果
            conn.execute(text("""
                INSERT INTO experiment_results
                (experiment_id, aggregated_data, created_at)
                VALUES (:exp_id, :data, NOW())
            """), {
                "exp_id": experiment_id,
                "data": json.dumps(aggregated)
            })

            return aggregated
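aggregate_results 中的指标聚合部分同样可以脱离数据库单独验证。以下为一个纯 Python 示意,与上面代码中的聚合公式一致:

```python
def aggregate_metrics(results):
    """聚合各 worker 上报的指标(results 为字典列表)"""
    if not results:
        return None
    n = len(results)
    return {
        "total_tasks": n,
        "avg_accuracy": sum(r["accuracy"] for r in results) / n,
        "avg_loss": sum(r["loss"] for r in results) / n,
        "best_accuracy": max(r["accuracy"] for r in results),
    }

# 示例:三个 worker 的训练结果
results = [
    {"accuracy": 0.90, "loss": 0.30},
    {"accuracy": 0.92, "loss": 0.28},
    {"accuracy": 0.88, "loss": 0.35},
]
agg = aggregate_metrics(results)
print(agg["best_accuracy"])  # 0.92
```

这样聚合口径(平均、最优)可以在任意环境下回归测试,事务只负责"读取已完成任务 + 写入聚合结果"这两步的一致性。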

6. 本章自测

练习1:事务设计

设计一个电商订单系统的事务流程,要求:

  1. 扣减库存
  2. 创建订单
  3. 扣减用户余额
  4. 记录交易日志

要求考虑并发安全和性能,选择合适的锁策略。

练习2:隔离级别选择

为以下场景选择合适的隔离级别,并说明理由:

  1. 实时数据看板(读取最新的汇总数据)
  2. 银行转账系统
  3. 模型训练日志记录
  4. 商品库存查询

练习3:死锁排查

以下代码存在死锁风险,请找出并修复:

Python
def transfer(a_id, b_id, amount):
    with engine.begin() as conn:
        # 获取账户A
        a = conn.execute(
            text("SELECT * FROM accounts WHERE id = :id FOR UPDATE"),
            {"id": a_id}
        ).fetchone()

        # 获取账户B
        b = conn.execute(
            text("SELECT * FROM accounts WHERE id = :id FOR UPDATE"),
            {"id": b_id}
        ).fetchone()

        # 转账逻辑...

练习4:AI场景事务

设计一个A/B测试系统的事务方案:

  - 需要记录实验分组
  - 记录用户行为
  - 实时计算指标
  - 支持实验停止和数据归档


7. 本章小结

核心知识点

  1. ACID特性:原子性、一致性、隔离性、持久性
  2. 隔离级别:READ UNCOMMITTED → READ COMMITTED → REPEATABLE READ → SERIALIZABLE
  3. 并发问题:脏读、不可重复读、幻读
  4. 锁机制:
     - 悲观锁:SELECT FOR UPDATE,适合高冲突场景
     - 乐观锁:版本号机制,适合读多写少
  5. 死锁处理:统一加锁顺序、设置超时、自动重试
  6. AI场景:训练事务管理、数据流水线、分布式协调

事务设计检查清单

Markdown
□ 明确事务边界,避免事务过长
□ 选择合适的隔离级别
□ 高并发场景选择合适的锁策略
□ 按固定顺序访问资源防止死锁
□ 设置合理的锁等待超时
□ 实现死锁检测和自动重试机制
□ 区分关键操作和非关键操作的事务要求
□ 定期监控死锁和锁等待情况
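清单中的"死锁检测和自动重试"通常用装饰器实现。以下为一个最小示意:DeadlockError 是示意用的自定义异常,真实场景应捕获驱动抛出的死锁错误(如 MySQL 错误码 1213 对应的异常):

```python
import time

class DeadlockError(Exception):
    """示意用的死锁异常(真实代码中对应数据库驱动的死锁错误)"""

def retry_on_deadlock(max_retries=3, backoff=0.05):
    """死锁自动重试装饰器:捕获死锁异常,退避后重试整个事务"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except DeadlockError:
                    if attempt == max_retries - 1:
                        raise  # 重试次数用尽,向上抛出
                    time.sleep(backoff * (attempt + 1))  # 线性退避
        return wrapper
    return decorator

# 示例:前两次模拟死锁,第三次成功
calls = {"n": 0}

@retry_on_deadlock(max_retries=3, backoff=0)
def flaky_transfer():
    calls["n"] += 1
    if calls["n"] < 3:
        raise DeadlockError("Deadlock found when trying to get lock")
    return "committed"

print(flaky_transfer())  # committed
```

注意被重试的必须是完整事务(从 BEGIN 到 COMMIT),不能只重试事务中的单条语句,否则会丢失已回滚的前置操作。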

下一步

完成本章学习后,继续学习 第09章:数据库与AI应用,了解数据库在AI/ML工作流中的高级应用场景,包括特征存储、向量数据库、MLOps数据管理等。


参考资源:

  - MySQL事务与锁
  - PostgreSQL事务隔离
  - 数据库事务设计模式