🔧 Model Deployment and Serving

⚠️ Freshness note: this chapter touches on fast-moving topics such as cutting-edge models, pricing, and leaderboards, which can change quickly between versions; always defer to the original papers, official release pages, and API documentation.

Learning objectives: master model serialization, inference serving frameworks, Docker containerization, Kubernetes deployment, A/B testing, and performance optimization. Estimated time: 12-15 hours. Prerequisites: PyTorch/TensorFlow basics, Docker basics, REST API concepts.


📋 Chapter Overview

Model deployment is the key step in turning a trained model into real business value. This chapter covers the complete pipeline from model serialization to a production-grade inference service.

Text Only
Model deployment pipeline:
┌──────────┐    ┌──────────┐    ┌──────────┐    ┌───────────┐
│  Export  │───▶│ Serving  │───▶│Container │───▶│Orchestrate│
│ ONNX/TS  │    │ Triton   │    │ Docker   │    │K8s/KServe │
└──────────┘    └──────────┘    └──────────┘    └───────────┘
     │               │               │                │
     ▼               ▼               ▼                ▼
 Serialization   Batching/       Image opt.      Autoscaling
 Graph opt.      concurrency     Multi-stage     A/B testing
                 Model warmup    builds

The diagram above shows the complete deployment pipeline: model serialization and export, inference service setup, Docker containerization, and Kubernetes orchestration. Each stage comes with its own optimization levers, ensuring the model reaches production efficiently and reliably.


1. Model Serialization

1.1 ONNX (Open Neural Network Exchange)

ONNX is a framework-agnostic intermediate representation for models, enabling cross-framework deployment.

Python
import torch
import torch.nn as nn
import numpy as np

# Define the model
class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30000, embed_dim=256, hidden_dim=512, num_classes=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids):
        embeds = self.embedding(input_ids)
        lstm_out, (hidden, _) = self.lstm(embeds)
        # Concatenate the final forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # hidden[-2]/[-1]: last forward/backward states
        logits = self.classifier(hidden)
        return logits

model = TextEncoder()
model.eval()

# ===== Export to ONNX =====
dummy_input = torch.randint(0, 30000, (1, 128))  # [batch_size, seq_len]

torch.onnx.export(
    model,
    dummy_input,
    "text_encoder.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,           # constant folding optimization
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={                       # dynamic axes
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "logits": {0: "batch_size"}
    }
)

print("ONNX model exported successfully!")

# ===== Validate the ONNX model =====
import onnx
onnx_model = onnx.load("text_encoder.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed!")

# ===== Inference with ONNX Runtime =====
import onnxruntime as ort

session = ort.InferenceSession(
    "text_encoder.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# Run inference
input_data = np.random.randint(0, 30000, (2, 128)).astype(np.int64)
outputs = session.run(None, {"input_ids": input_data})
print(f"Output shape: {outputs[0].shape}")  # [2, 10]

1.2 ONNX Optimization

Python
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.transformers import optimizer

# ===== 1. Dynamic quantization =====
quantize_dynamic(
    model_input="text_encoder.onnx",
    model_output="text_encoder_int8.onnx",
    weight_type=QuantType.QInt8
)

# ===== 2. Graph optimization (Transformer models) =====
# optimized_model = optimizer.optimize_model(
#     "bert_model.onnx",
#     model_type="bert",
#     num_heads=12,
#     hidden_size=768
# )
# optimized_model.save_model_to_file("bert_optimized.onnx")

# ===== 3. Performance benchmarking =====
import time

def benchmark_onnx(model_path: str, input_data: dict, num_runs: int = 100):
    """ONNX模型性能基准测试"""
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    # Warmup
    for _ in range(10):
        session.run(None, input_data)

    # Benchmark
    latencies = []
    for _ in range(num_runs):
        start = time.perf_counter()
        session.run(None, input_data)
        latencies.append((time.perf_counter() - start) * 1000)

    print(f"Model: {model_path}")
    print(f"  Mean latency: {np.mean(latencies):.2f} ms")
    print(f"  P50 latency:  {np.percentile(latencies, 50):.2f} ms")
    print(f"  P95 latency:  {np.percentile(latencies, 95):.2f} ms")
    print(f"  P99 latency:  {np.percentile(latencies, 99):.2f} ms")

    return latencies

# test_input = {"input_ids": np.random.randint(0, 30000, (1, 128)).astype(np.int64)}
# benchmark_onnx("text_encoder.onnx", test_input)
# benchmark_onnx("text_encoder_int8.onnx", test_input)

1.3 TorchScript

Python
import torch

model = TextEncoder()
model.eval()

# ===== Option 1: Tracing (recommended for simple models) =====
dummy_input = torch.randint(0, 30000, (1, 128))
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("text_encoder_traced.pt")

# ===== Option 2: Scripting (supports dynamic control flow) =====
class DynamicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        # Contains a conditional branch, so scripting is required
        if x.shape[0] > 1:
            return self.linear(x).mean(dim=0)
        else:
            return self.linear(x).squeeze(0)

dynamic_model = DynamicModel()
scripted_model = torch.jit.script(dynamic_model)
scripted_model.save("dynamic_model_scripted.pt")

# Load and run inference
loaded_model = torch.jit.load("text_encoder_traced.pt")
with torch.no_grad():
    output = loaded_model(dummy_input)
    print(f"TorchScript output shape: {output.shape}")

1.4 SavedModel (TensorFlow)

Python
import tensorflow as tf

# Define the model
class TFClassifier(tf.keras.Model):
    def __init__(self, num_classes=10):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(256, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.3)
        self.dense2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

# Build and save
model = TFClassifier()
model(tf.random.normal((1, 768)))  # initialize the weights

# SavedModel format
tf.saved_model.save(model, "saved_model/classifier")

# Save with a serving signature (recommended)
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 768], dtype=tf.float32)])
def serve(inputs):
    return {"predictions": model(inputs, training=False)}

tf.saved_model.save(model, "saved_model/classifier_v2", signatures={"serving_default": serve})

# Load and run inference
loaded = tf.saved_model.load("saved_model/classifier_v2")
result = loaded.signatures["serving_default"](tf.random.normal((2, 768)))
print(f"TF output shape: {result['predictions'].shape}")

2. Inference Serving Frameworks

2.1 Quick Deployment with FastAPI

Python
"""
简单的模型推理服务
运行:uvicorn serve:app --host 0.0.0.0 --port 8000
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel  # Pydantic models for request validation
import numpy as np
import onnxruntime as ort
import time
from contextlib import asynccontextmanager

# Global model handle
model_session = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """应用生命周期管理:启动时加载模型"""
    global model_session
    model_session = ort.InferenceSession(
        "text_encoder.onnx",
        providers=["CPUExecutionProvider"]
    )
    print("Model loaded successfully!")
    yield
    print("Shutting down...")

app = FastAPI(title="Text Classification API", version="1.0", lifespan=lifespan)

class PredictRequest(BaseModel):
    input_ids: list[list[int]]  # [batch_size, seq_len]

class PredictResponse(BaseModel):
    predictions: list[int]
    probabilities: list[list[float]]
    latency_ms: float

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """模型推理接口"""
    try:  # try/except捕获异常
        start = time.perf_counter()

        input_array = np.array(request.input_ids, dtype=np.int64)
        outputs = model_session.run(None, {"input_ids": input_array})
        logits = outputs[0]

        # Softmax probabilities
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)

        predictions = np.argmax(logits, axis=-1).tolist()

        latency = (time.perf_counter() - start) * 1000

        return PredictResponse(
            predictions=predictions,
            probabilities=probs.tolist(),
            latency_ms=round(latency, 2)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    """健康检查"""
    return {"status": "healthy", "model_loaded": model_session is not None}

2.2 Triton Inference Server

Triton is NVIDIA's high-performance inference server, supporting multiple frameworks, dynamic batching, and model ensembles.

Text Only
# Triton model repository layout
model_repository/
├── text_encoder/
│   ├── config.pbtxt      # model configuration
│   ├── 1/                 # version 1
│   │   └── model.onnx
│   └── 2/                 # version 2
│       └── model.onnx
└── ensemble_model/
    ├── config.pbtxt
    └── 1/
Protocol Buffer
# config.pbtxt - Triton model configuration
name: "text_encoder"
platform: "onnxruntime_onnx"
max_batch_size: 64

input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]  # dynamic sequence length
  }
]

output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [ 10 ]
  }
]

# Dynamic batching configuration
dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]
  max_queue_delay_microseconds: 100
}

# Instance group configuration
instance_group [
  {
    count: 2
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]

# Model optimization
optimization {
  execution_accelerators {
    gpu_execution_accelerator: [ {
      name: "tensorrt"
      parameters { key: "precision_mode" value: "FP16" }
      parameters { key: "max_workspace_size_bytes" value: "1073741824" }
    }]
  }
}
Python
# Triton client call
import tritonclient.http as httpclient
import numpy as np

def triton_inference(input_ids: np.ndarray):
    """调用Triton服务"""
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # Check server readiness
    if not client.is_server_ready():
        raise RuntimeError("Triton server not ready")

    # Build the inputs
    inputs = [
        httpclient.InferInput("input_ids", input_ids.shape, "INT64")
    ]
    inputs[0].set_data_from_numpy(input_ids)

    # Build the requested outputs
    outputs = [
        httpclient.InferRequestedOutput("logits")
    ]

    # Run inference
    result = client.infer(
        model_name="text_encoder",
        model_version="1",
        inputs=inputs,
        outputs=outputs
    )

    logits = result.as_numpy("logits")
    return logits

# Usage
# input_data = np.random.randint(0, 30000, (4, 128)).astype(np.int64)
# result = triton_inference(input_data)
# print(f"Triton result shape: {result.shape}")
Bash
# Launch the Triton server
# docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \
#   -v $(pwd)/model_repository:/models \
#   nvcr.io/nvidia/tritonserver:24.01-py3 \
#   tritonserver --model-repository=/models

2.3 TorchServe

Python
# 1. Create a model handler (handler.py)
"""
A custom TorchServe handler.
"""
import json
import torch
import numpy as np
from ts.torch_handler.base_handler import BaseHandler

class TextClassificationHandler(BaseHandler):

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, context):
        """模型初始化"""
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Load the model
        self.model = torch.jit.load(f"{model_dir}/model.pt")
        self.model.eval()

        # Load the label mapping
        with open(f"{model_dir}/labels.json", "r") as f:
            self.labels = json.load(f)  # json.load parses JSON from a file object

        self.initialized = True

    def preprocess(self, data):
        """预处理输入"""
        inputs = []
        for row in data:
            input_data = row.get("data") or row.get("body")
            if isinstance(input_data, (bytes, bytearray)):
                input_data = input_data.decode("utf-8")
            if isinstance(input_data, str):
                input_data = json.loads(input_data)
            inputs.append(input_data["input_ids"])
        return torch.tensor(inputs, dtype=torch.long)

    def inference(self, input_batch):
        """模型推理"""
        with torch.no_grad():
            logits = self.model(input_batch)
            probs = torch.softmax(logits, dim=-1)
            predictions = torch.argmax(probs, dim=-1)
        return predictions, probs

    def postprocess(self, inference_output):
        """后处理输出"""
        predictions, probs = inference_output
        results = []
        for pred, prob in zip(predictions, probs):
            results.append({
                "label": self.labels[pred.item()],
                "confidence": prob[pred].item(),
                "all_probabilities": {
                    self.labels[i]: p.item()
                    for i, p in enumerate(prob)
                }
            })
        return results
Bash
# 2. Package the model archive (MAR)
# torch-model-archiver --model-name text_classifier \
#   --version 1.0 \
#   --serialized-file model.pt \
#   --handler handler.py \
#   --extra-files labels.json \
#   --export-path model_store/

# 3. Launch TorchServe
# torchserve --start --model-store model_store/ \
#   --models text_classifier=text_classifier.mar \
#   --ts-config config.properties

3. Containerized Deployment with Docker

3.1 Inference Service Dockerfile

Docker
# ===== Multi-stage build =====
# Stage 1: build stage
FROM python:3.11-slim AS builder

WORKDIR /build
COPY requirements.txt .
RUN pip install --no-cache-dir --prefix=/install -r requirements.txt

# Stage 2: runtime stage
FROM python:3.11-slim AS runtime

# Install required system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl && \
    rm -rf /var/lib/apt/lists/*

# Copy the Python packages
COPY --from=builder /install /usr/local

# Create a non-root user
RUN useradd -m -u 1000 appuser

WORKDIR /app

# Copy the model and code
COPY --chown=appuser:appuser model/ ./model/
COPY --chown=appuser:appuser serve.py .
COPY --chown=appuser:appuser config.yaml .

USER appuser

# Environment variables
ENV MODEL_PATH=/app/model/text_encoder.onnx
ENV WORKERS=4
ENV PORT=8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
    CMD curl -f http://localhost:${PORT}/health || exit 1

EXPOSE ${PORT}

# Run with Gunicorn using Uvicorn workers
CMD ["sh", "-c", "gunicorn serve:app -w ${WORKERS} -k uvicorn.workers.UvicornWorker -b 0.0.0.0:${PORT} --timeout 120"]
Text Only
# requirements.txt
fastapi==0.115.0
uvicorn==0.30.0
gunicorn==22.0.0
onnxruntime==1.18.0
numpy==1.26.4
pydantic==2.7.0

3.2 GPU Inference Dockerfile

Docker
# FROM sets the base image
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# Install Python and system dependencies (RUN executes commands at build time)
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.11 python3.11-venv python3-pip curl && \
    rm -rf /var/lib/apt/lists/*

# WORKDIR sets the working directory
WORKDIR /app

# Install Python dependencies (COPY copies files into the image)
COPY requirements-gpu.txt .
RUN pip3 install --no-cache-dir -r requirements-gpu.txt

# Copy the code and model
COPY serve.py .
COPY model/ ./model/

ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# EXPOSE documents the listening port; CMD is the container's default command
EXPOSE 8000
CMD ["python3", "-m", "uvicorn", "serve:app", "--host", "0.0.0.0", "--port", "8000"]

3.3 Orchestration with Docker Compose

YAML
# compose.yaml
# Docker Compose V2 no longer needs a version field
services:  # each entry under services defines one container
  inference-server:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/app/model/text_encoder.onnx
      - WORKERS=4
      - LOG_LEVEL=info
    volumes:
      - ./model:/app/model:ro
    deploy:
      resources:
        limits:
          memory: 4G
          cpus: "4.0"
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3

  # GPU inference service
  gpu-inference:
    build:
      context: .
      dockerfile: Dockerfile.gpu
    ports:
      - "8001:8000"
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped

  # Nginx load balancer
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
    depends_on:
      - inference-server
    restart: unless-stopped

  # Prometheus monitoring
  prometheus:
    image: prom/prometheus:latest
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
    restart: unless-stopped

4. Model Serving on Kubernetes

4.1 Basic K8s Deployment

YAML
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: text-classifier
  labels:
    app: text-classifier
    version: v1
spec:
  replicas: 3
  selector:
    matchLabels:
      app: text-classifier
  template:
    metadata:
      labels:
        app: text-classifier
        version: v1
    spec:
      containers:
        - name: inference
          image: registry.example.com/text-classifier:v1.2.0
          ports:
            - containerPort: 8000
          resources:
            requests:
              cpu: "2"
              memory: "4Gi"
            limits:
              cpu: "4"
              memory: "8Gi"
              # nvidia.com/gpu: "1"  # for GPU inference
          env:
            - name: MODEL_PATH
              value: /app/model/text_encoder.onnx
            - name: WORKERS
              value: "4"
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 60
            periodSeconds: 30
          volumeMounts:
            - name: model-volume
              mountPath: /app/model
              readOnly: true
      volumes:
        - name: model-volume
          persistentVolumeClaim:
            claimName: model-pvc
---  # YAML document separator
# service.yaml
apiVersion: v1
kind: Service
metadata:
  name: text-classifier-svc
spec:
  selector:
    app: text-classifier
  ports:
    - port: 80
      targetPort: 8000
  type: ClusterIP
---
# hpa.yaml - autoscaling
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: text-classifier-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: text-classifier
  minReplicas: 2
  maxReplicas: 10
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Pods
      pods:
        metric:
          name: inference_requests_per_second
        target:
          type: AverageValue
          averageValue: "100"

4.2 KServe (ML Serving on Kubernetes)

YAML
# kserve-inference-service.yaml
apiVersion: serving.kserve.io/v1beta1
kind: InferenceService
metadata:
  name: text-classifier
  annotations:
    sidecar.istio.io/inject: "true"
spec:
  predictor:
    # Autoscaling configuration
    minReplicas: 1
    maxReplicas: 5
    scaleTarget: 10       # target concurrency per replica
    scaleMetric: concurrency

    # Container configuration
    containers:
      - name: kserve-container
        image: registry.example.com/text-classifier:v1.2.0
        ports:
          - containerPort: 8080
            protocol: TCP
        resources:
          requests:
            cpu: "2"
            memory: "4Gi"
          limits:
            cpu: "4"
            memory: "8Gi"
        env:
          - name: MODEL_NAME
            value: text-classifier

    # Or use the built-in framework support
    # sklearn:
    #   storageUri: gs://my-bucket/models/sklearn/1.0
    # tensorflow:
    #   storageUri: gs://my-bucket/models/tf/1.0
    # pytorch:
    #   storageUri: gs://my-bucket/models/pytorch/1.0

  # Transformer (pre-/post-processing)
  transformer:
    containers:
      - name: preprocessor
        image: registry.example.com/text-preprocessor:v1.0
        resources:
          requests:
            cpu: "1"
            memory: "2Gi"
Python
# A custom KServe predictor
from kserve import Model, ModelServer
import numpy as np
import onnxruntime as ort

class TextClassifierModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.name = name
        self.model = None
        self.ready = False

    def load(self):
        """加载模型"""
        self.model = ort.InferenceSession(
            "/mnt/models/model.onnx",
            providers=["CPUExecutionProvider"]
        )
        self.ready = True

    def predict(self, payload: dict, headers: dict = None) -> dict:
        """推理"""
        input_ids = np.array(payload["instances"], dtype=np.int64)
        outputs = self.model.run(None, {"input_ids": input_ids})
        predictions = np.argmax(outputs[0], axis=-1).tolist()
        return {"predictions": predictions}

if __name__ == "__main__":
    model = TextClassifierModel("text-classifier")
    model.load()
    ModelServer().start([model])

5. A/B Testing and Canary Releases

5.1 A/B Testing via Traffic Splitting

YAML
# A/B testing with an Istio VirtualService
apiVersion: networking.istio.io/v1beta1  # apiVersion selects the K8s API version
kind: VirtualService  # kind selects the resource type
metadata:
  name: text-classifier-vs
spec:  # spec declares the desired state
  hosts:
    - text-classifier.example.com
  http:
    - match:
        - headers:
            x-model-version:
              exact: "v2"
      route:
        - destination:
            host: text-classifier-v2
            port:
              number: 80
    - route:
        - destination:
            host: text-classifier-v1
            port:
              number: 80
          weight: 90
        - destination:
            host: text-classifier-v2
            port:
              number: 80
          weight: 10
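
With these rules, a client can pin a specific model version through the `x-model-version` header, while unpinned traffic follows the 90/10 weighted split. A hedged sketch (the hostname mirrors the VirtualService above):

Python
# Route explicitly to v2 via header; unpinned requests follow the 90/10 split
import requests

BASE_URL = "http://text-classifier.example.com/predict"
payload = {"input_ids": [[101, 2023, 102]]}  # toy token IDs

# Pinned: Istio matches the header and routes to text-classifier-v2
pinned = requests.post(BASE_URL, json=payload,
                       headers={"x-model-version": "v2"}, timeout=5)

# Unpinned: the weighted route sends ~10% of these to v2
unpinned = requests.post(BASE_URL, json=payload, timeout=5)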

5.2 Canary Release Controller

Python
"""
金丝雀发布管理器
逐步将流量从旧版本迁移到新版本
"""
import time
from dataclasses import dataclass

@dataclass
class CanaryConfig:
    """金丝雀发布配置"""
    initial_weight: int = 5       # 初始流量比例 %
    step_weight: int = 10         # 每步增加 %
    max_weight: int = 100         # 最大流量
    step_interval: int = 300      # 每步间隔(秒)
    error_threshold: float = 0.01 # 错误率阈值
    latency_threshold_ms: float = 200  # 延迟阈值
    rollback_on_failure: bool = True

class CanaryDeployment:
    """金丝雀部署管理器"""

    def __init__(self, config: CanaryConfig):
        self.config = config
        self.current_weight = 0
        self.is_active = False

    def start(self, old_version: str, new_version: str):
        """开始金丝雀发布"""
        self.old_version = old_version
        self.new_version = new_version
        self.current_weight = self.config.initial_weight
        self.is_active = True

        print(f"Starting canary: {old_version} -> {new_version}")
        print(f"Initial traffic split: {self.current_weight}% to new version")

        self._update_traffic_split()

    def check_health(self) -> dict:
        """检查新版本的健康状况"""
        # 实际实现中从Prometheus/监控系统获取指标
        metrics = {
            "error_rate": 0.005,       # 模拟:0.5%错误率
            "p99_latency_ms": 150,     # 模拟:150ms p99延迟
            "accuracy": 0.92,          # 模拟:92%准确率
        }
        return metrics

    def should_proceed(self, metrics: dict) -> bool:
        """判断是否可以继续推进"""
        if metrics["error_rate"] > self.config.error_threshold:
            print(f"ERROR: Error rate {metrics['error_rate']:.3f} exceeds threshold")
            return False
        if metrics["p99_latency_ms"] > self.config.latency_threshold_ms:
            print(f"ERROR: Latency {metrics['p99_latency_ms']:.0f}ms exceeds threshold")
            return False
        return True

    def advance(self):
        """推进金丝雀流量"""
        metrics = self.check_health()

        if not self.should_proceed(metrics):
            if self.config.rollback_on_failure:
                self.rollback()
            return False

        self.current_weight = min(
            self.current_weight + self.config.step_weight,
            self.config.max_weight
        )

        print(f"Advancing canary: {self.current_weight}% traffic to {self.new_version}")
        self._update_traffic_split()

        if self.current_weight >= self.config.max_weight:
            self.promote()
            return True

        return True

    def rollback(self):
        """回滚到旧版本"""
        print(f"ROLLBACK: Reverting to {self.old_version}")
        self.current_weight = 0
        self._update_traffic_split()
        self.is_active = False

    def promote(self):
        """完全切换到新版本"""
        print(f"PROMOTED: {self.new_version} is now serving 100% traffic")
        self.is_active = False

    def _update_traffic_split(self):
        """更新流量分配(实际中调用K8s API)"""
        old_weight = 100 - self.current_weight
        print(f"  Traffic: {self.old_version}={old_weight}%, {self.new_version}={self.current_weight}%")

# Usage example
config = CanaryConfig(initial_weight=5, step_weight=15, step_interval=60)
canary = CanaryDeployment(config)
canary.start("v1.0", "v2.0")

# Simulate stepwise advancement
for step in range(8):
    print(f"\n--- Step {step + 1} ---")
    if not canary.advance():
        break
    if not canary.is_active:
        break

6. API Design and Performance Optimization

6.1 Batched Inference

Python
import asyncio
import time
import numpy as np
from collections import deque
from dataclasses import dataclass, field
from threading import Lock  # Lock guards the shared queue across threads

@dataclass  # auto-generates __init__ and other boilerplate
class InferenceRequest:
    """A single inference request."""
    input_ids: np.ndarray
    future: asyncio.Future = field(default=None)
    timestamp: float = field(default_factory=time.time)

class DynamicBatcher:
    """
    动态批处理:将多个请求合并为批次,提高GPU利用率
    """

    def __init__(
        self,
        model_session,
        max_batch_size: int = 32,
        max_wait_time_ms: float = 50,
    ):
        self.model = model_session
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time_ms / 1000
        self.queue: deque = deque()
        self.lock = Lock()
        self._running = True

    async def predict(self, input_ids: np.ndarray) -> np.ndarray:
        """提交推理请求,等待批处理结果"""
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        request = InferenceRequest(
            input_ids=input_ids,
            future=future
        )

        with self.lock:
            self.queue.append(request)

        # Process immediately once the batch is full
        if len(self.queue) >= self.max_batch_size:
            await self._process_batch()

        return await future

    async def _process_batch(self):
        """处理一个批次"""
        with self.lock:
            batch_size = min(len(self.queue), self.max_batch_size)
            if batch_size == 0:
                return

            requests = [self.queue.popleft() for _ in range(batch_size)]

        # Merge the inputs
        batch_input = np.concatenate([r.input_ids for r in requests], axis=0)

        # Batched inference
        outputs = self.model.run(None, {"input_ids": batch_input})
        results = outputs[0]

        # Distribute the results
        offset = 0
        for req in requests:
            size = req.input_ids.shape[0]
            result = results[offset:offset + size]
            if not req.future.done():
                req.future.set_result(result)
            offset += size

    async def batch_processor(self):
        """后台批处理循环"""
        while self._running:
            if len(self.queue) > 0:
                oldest = self.queue[0].timestamp
                if (time.time() - oldest) >= self.max_wait_time or \
                   len(self.queue) >= self.max_batch_size:
                    await self._process_batch()
            await asyncio.sleep(0.001)  # 1 ms polling interval
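
A wiring sketch showing how the batcher could sit behind the FastAPI service from section 2.1, assuming the `DynamicBatcher` class above and the same ONNX model file (the background loop is started in the lifespan hook):

Python
# Wiring sketch: DynamicBatcher behind FastAPI (assumes the class defined above)
import asyncio
from contextlib import asynccontextmanager

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    session = ort.InferenceSession("text_encoder.onnx", providers=["CPUExecutionProvider"])
    app.state.batcher = DynamicBatcher(session, max_batch_size=32, max_wait_time_ms=50)
    # Keep the batching loop running for the lifetime of the app
    loop_task = asyncio.create_task(app.state.batcher.batch_processor())
    yield
    app.state.batcher._running = False
    loop_task.cancel()

app = FastAPI(lifespan=lifespan)

@app.post("/predict")
async def predict(request: dict):
    input_ids = np.array(request["input_ids"], dtype=np.int64)
    logits = await app.state.batcher.predict(input_ids)
    return {"predictions": np.argmax(logits, axis=-1).tolist()}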

6.2 Async Inference and Streaming Responses

Python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio
import json

app = FastAPI()

async def async_inference(input_data: dict):  # async def declares a coroutine; call it with await
    """Async inference."""
    # Simulate a slow inference call
    await asyncio.sleep(0.1)
    return {"prediction": 1, "confidence": 0.95}

@app.post("/predict/async")
async def predict_async(request: dict):
    """异步推理接口 - 返回任务ID"""
    task_id = f"task_{id(request)}"
    # Submit a background task
    asyncio.create_task(process_task(task_id, request))
    return {"task_id": task_id, "status": "processing"}

@app.post("/predict/stream")
async def predict_stream(request: dict):
    """流式推理接口 - 适用于LLM生成"""
    async def generate():
        tokens = ["Hello", " world", "!", " This", " is", " streaming", "."]
        for token in tokens:
            yield f"data: {json.dumps({'token': token})}\n\n"  # yield生成器:惰性产出值,节省内存
            await asyncio.sleep(0.1)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )

async def process_task(task_id: str, request: dict):
    """Process the task in the background."""
    result = await async_inference(request)
    # Persist the result to Redis/a database
    print(f"Task {task_id} completed: {result}")

6.3 Performance Optimization Checklist

Python
"""
推理性能优化策略汇总
"""

class InferenceOptimizer:
    """推理优化策略集合"""

    @staticmethod  # @staticmethod静态方法,不需要实例
    def model_optimization():
        """模型层面优化"""
        strategies = {
            "量化": "FP32 → FP16/INT8,减少计算量和内存",
            "剪枝": "移除冗余参数,减少模型大小",
            "蒸馏": "大模型指导小模型,保持性能降低成本",
            "ONNX优化": "图优化、常量折叠、算子融合",
            "TensorRT": "NVIDIA GPU专用优化,FP16/INT8加速",
        }
        return strategies

    @staticmethod
    def serving_optimization():
        """Serving-level optimizations."""
        strategies = {
            "Dynamic batching": "merge requests to raise GPU utilization",
            "Model warmup": "preload the model and run warmup inferences at startup",
            "Connection pooling": "reuse HTTP/gRPC connections",
            "Async processing": "non-blocking IO for higher concurrency",
            "Caching": "cache results for identical inputs to avoid recomputation",
            "Multiple workers": "Gunicorn multi-process serving to use every CPU core",
        }
        return strategies

    @staticmethod
    def infrastructure_optimization():
        """Infrastructure-level optimizations."""
        strategies = {
            "Autoscaling": "HPA scales replicas with load",
            "GPU sharing": "MIG/MPS lets multiple models share one GPU",
            "Load balancing": "Nginx/Istio distributes requests",
            "CDN caching": "cache static assets to save bandwidth",
            "Regional deployment": "deploy close to users to cut network latency",
        }
        return strategies

# Model warmup example
def warmup_model(session, input_shape, num_warmup: int = 10):
    """Warm up the model to remove cold-start latency on the first request."""
    import numpy as np

    dummy_input = np.random.randint(0, 1000, input_shape).astype(np.int64)

    for i in range(num_warmup):
        session.run(None, {"input_ids": dummy_input})

    print(f"Model warmup completed ({num_warmup} runs)")

💼 Common Interview Questions

Q1: What is the difference between ONNX and TorchScript? When should you choose each?

A: ONNX is a cross-framework intermediate representation, suited to heterogeneous deployments (different frameworks, different hardware); paired with ONNX Runtime/TensorRT it offers the best performance. TorchScript is PyTorch's native serialization format; it supports dynamic control flow (script mode) and suits deployment within the PyTorch ecosystem. Rule of thumb: need cross-framework portability or maximum performance → ONNX; need complex control flow or want to stay in the PyTorch ecosystem → TorchScript.
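
A small sketch of the control-flow point: tracing bakes in whichever branch the example input takes, while scripting preserves the `if`:

Python
# Tracing records one branch; scripting keeps the control flow
import torch
import torch.nn as nn

class Branchy(nn.Module):
    def forward(self, x):
        if x.shape[0] > 1:
            return x.sum(dim=0)
        return x.squeeze(0)

m = Branchy()
traced = torch.jit.trace(m, torch.randn(1, 4))  # traces only the batch-size-1 branch
scripted = torch.jit.script(m)                  # compiles both branches

batch = torch.randn(3, 4)
print(scripted(batch).shape)  # torch.Size([4]): sums over the batch, as intended
print(traced(batch).shape)    # torch.Size([3, 4]): the stale traced branch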

Q2: How do you design a highly available inference service?

A: ① Multi-replica deployment plus load balancing ② health checks (readiness/liveness probes) ③ autoscaling (HPA/KEDA) ④ graceful shutdown ⑤ circuit breakers ⑥ request timeouts and retries ⑦ model version management with fast rollback ⑧ monitoring and alerting (latency/error rate/throughput).
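
As a concrete sketch of point ⑥, client-side timeouts plus retries with exponential backoff (the endpoint and parameters are illustrative):

Python
# Timeout + retry with exponential backoff (endpoint is illustrative)
import time
import requests

def predict_with_retry(payload: dict, max_retries: int = 3,
                       timeout_s: float = 2.0) -> dict:
    for attempt in range(max_retries):
        try:
            resp = requests.post("http://localhost:8000/predict",
                                 json=payload, timeout=timeout_s)
            resp.raise_for_status()
            return resp.json()
        except (requests.Timeout, requests.ConnectionError, requests.HTTPError):
            if attempt == max_retries - 1:
                raise  # exhausted retries: surface the error to the caller
            time.sleep(0.1 * 2 ** attempt)  # back off: 0.1s, 0.2s, 0.4s ...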

Q3: How does dynamic batching work? What are its pros and cons?

A: Dynamic batching collects requests on the server over a short time window and merges them into one batch for the GPU. Pros: higher GPU utilization, since batched inference is much faster than per-request inference. Cons: it adds latency to individual requests (they wait for the batch to fill), and it requires tuning (max wait time, max batch size) to balance latency against throughput. Both Triton and TorchServe support it out of the box.

Q4: What is the difference between a canary release and an A/B test?

A: A canary release is a progressive deployment strategy: traffic shifts from the old version to the new one in increments, with immediate rollback if problems appear; the goal is a safe launch. An A/B test is an experimental method: users are split into control and treatment groups to evaluate the new model's business impact (CTR, conversion rate, etc.), and it requires statistical significance testing. The two combine well: first a canary to verify stability, then an A/B test to verify effectiveness.
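
For the significance-testing half of the answer, a minimal two-proportion z-test over conversion counts (the numbers are made up):

Python
# Two-proportion z-test for an A/B conversion experiment (made-up numbers)
import math

def two_proportion_z(conv_a: int, n_a: int, conv_b: int, n_b: int) -> float:
    """z statistic for H0: the two conversion rates are equal."""
    p_a, p_b = conv_a / n_a, conv_b / n_b
    p_pool = (conv_a + conv_b) / (n_a + n_b)
    se = math.sqrt(p_pool * (1 - p_pool) * (1 / n_a + 1 / n_b))
    return (p_b - p_a) / se

z = two_proportion_z(conv_a=480, n_a=10_000, conv_b=560, n_b=10_000)
print(f"z = {z:.2f}")  # |z| > 1.96 → significant at the 5% level (two-sided)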

Q5: How do you monitor the health of an inference service after deployment?

A: On three levels: ① infrastructure metrics (CPU/GPU utilization, memory, network) ② service metrics (QPS, P50/P95/P99 latency, error rate, timeout rate) ③ business metrics (prediction distribution, confidence distribution, input feature statistics). Collect with Prometheus, visualize with Grafana, and alert via AlertManager.
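
A hedged sketch of the service-metric layer using the `prometheus_client` library (metric names and buckets are illustrative):

Python
# Service metrics with prometheus_client (metric names are illustrative)
import random
import time

from prometheus_client import Counter, Histogram, start_http_server

REQUESTS = Counter("inference_requests_total", "Total inference requests", ["status"])
LATENCY = Histogram("inference_latency_seconds", "Inference latency",
                    buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1.0))

def handle_request():
    start = time.perf_counter()
    try:
        time.sleep(random.uniform(0.01, 0.05))  # stand-in for real inference
        REQUESTS.labels(status="ok").inc()
    except Exception:
        REQUESTS.labels(status="error").inc()
        raise
    finally:
        LATENCY.observe(time.perf_counter() - start)

if __name__ == "__main__":
    start_http_server(9100)  # Prometheus scrapes http://localhost:9100/metrics
    for _ in range(1000):
        handle_request()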

Q6: How do you optimize inference latency? What does the full optimization chain from model to service look like?

A:
  • Model level: quantization (FP16/INT8) → pruning → distillation → ONNX graph optimization → TensorRT compilation
  • Serving level: model warmup → dynamic batching → async IO → connection pooling → result caching
  • Infrastructure level: GPU selection → multi-replica load balancing → regional deployment → autoscaling
  • Application level: faster input preprocessing → less serialization overhead → gRPC instead of REST
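
As one example from the serving level, a minimal result cache keyed on the raw input bytes (`run_model` is a stand-in for the real session call):

Python
# Result cache for identical inputs (run_model is a stand-in)
import hashlib
import numpy as np

_cache: dict[str, np.ndarray] = {}

def run_model(input_ids: np.ndarray) -> np.ndarray:
    # Stand-in: in reality this would be session.run(...)
    return input_ids.sum(axis=-1, keepdims=True).astype(np.float32)

def cached_predict(input_ids: np.ndarray) -> np.ndarray:
    key = hashlib.sha256(input_ids.tobytes()).hexdigest()
    if key not in _cache:
        _cache[key] = run_model(input_ids)
    return _cache[key]

x = np.random.randint(0, 30000, (1, 128)).astype(np.int64)
assert cached_predict(x) is cached_predict(x)  # second call is a cache hit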

Q7: How do you achieve zero-downtime model deployment?

A: ① Rolling update: K8s replaces Pods one by one ② Blue-green: run old and new versions side by side and switch the Ingress traffic ③ Canary with automatic rollback ④ KServe's multi-version support with explicit traffic splits ⑤ hot model reload: swap the model file at runtime without restarting the service.
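
A sketch of point ⑤, hot-swapping the model at runtime: build the new session first, then swap it behind a lock so in-flight requests keep a consistent handle:

Python
# Hot model reload: build the new session, then swap atomically
import threading
import onnxruntime as ort

class HotSwappableModel:
    def __init__(self, model_path: str):
        self._lock = threading.Lock()
        self._session = ort.InferenceSession(model_path)

    def reload(self, new_model_path: str):
        """Load the new model fully before swapping; no downtime."""
        new_session = ort.InferenceSession(new_model_path)
        with self._lock:
            self._session = new_session

    def predict(self, inputs: dict):
        with self._lock:
            session = self._session  # grab a consistent handle
        return session.run(None, inputs)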

Q8: Design an inference service architecture that supports multiple models.

A: Use Triton or build your own gateway: ① a model registry (stores model metadata, versions, endpoints) ② a unified API gateway (routes requests to the right model service) ③ a model management service (loads/unloads/updates models) ④ resource scheduling (assigns GPU/CPU dynamically by load) ⑤ shared infrastructure (logging, monitoring, auth, rate limiting) ⑥ model orchestration (pipelines where one request flows through several models).
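
A minimal in-memory sketch of points ①②, a model registry plus router (a real system would back this with a database and service discovery):

Python
# Minimal model registry and router (in-memory, illustrative only)
from dataclasses import dataclass

@dataclass
class ModelEndpoint:
    name: str
    version: str
    url: str  # where this model's inference service lives

class ModelRegistry:
    def __init__(self):
        self._models: dict[tuple[str, str], ModelEndpoint] = {}
        self._default: dict[str, str] = {}

    def register(self, ep: ModelEndpoint, default: bool = False):
        self._models[(ep.name, ep.version)] = ep
        if default or ep.name not in self._default:
            self._default[ep.name] = ep.version

    def route(self, name: str, version: str | None = None) -> ModelEndpoint:
        ver = version or self._default[name]
        return self._models[(name, ver)]

registry = ModelRegistry()
registry.register(ModelEndpoint("text-classifier", "v1", "http://tc-v1:8000"))
registry.register(ModelEndpoint("text-classifier", "v2", "http://tc-v2:8000"), default=True)
print(registry.route("text-classifier").url)  # http://tc-v2:8000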


✅ Learning Checklist

  • Can export a PyTorch model to ONNX and validate it
  • Understand the difference between TorchScript tracing and scripting
  • Can build an inference service with FastAPI
  • Understand how to configure and use Triton Inference Server
  • Can write a multi-stage Docker build
  • Understand K8s Deployment/Service/HPA configuration
  • Understand KServe's InferenceService concept
  • Can design a canary release plan
  • Understand how dynamic batching works and how to implement it
  • Can design a highly available inference service architecture
  • Can answer all the interview questions above

📌 Next chapter: 03 Monitoring and Continuous Optimization: model monitoring and continuous iteration in production